
API Reference

tnh_scholar

TNH Scholar: Text Processing and Analysis Tools

TNH Scholar is an AI-driven project designed to explore, query, process, and translate the teachings of Thich Nhat Hanh and other Plum Village Dharma Teachers. The project
aims to create a resource for practitioners and scholars to deeply engage with
mindfulness and spiritual wisdom through natural language processing and machine
learning models.

Core Features
  • Audio transcription and processing
  • Multi-lingual text processing and translation
  • Pattern-based text analysis
  • OCR processing for historical documents
  • CLI tools for batch processing
Package Structure
  • tnh_scholar/
      • CLI_tools/ - Command line interface tools
      • audio_processing/ - Audio file handling and transcription
      • journal_processing/ - Journal and publication processing
      • ocr_processing/ - Optical character recognition tools
      • text_processing/ - Core text processing utilities
      • video_processing/ - Video file handling and transcription
      • utils/ - Shared utility functions
      • xml_processing/ - XML parsing and generation
Environment Configuration

The package uses environment variables for configuration, including:
  • TNH_PATTERN_DIR - Directory for text processing patterns
  • OPENAI_API_KEY - OpenAI API authentication
  • GOOGLE_VISION_KEY - Google Cloud Vision API key for OCR
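The pattern-directory lookup can be sketched as follows. This is a minimal sketch of the env-var precedence only; the hard-coded fallback path shown for TNH_DEFAULT_PATTERN_DIR is illustrative (the real default is derived from the project root, as shown in the module attributes below).

```python
import os
from pathlib import Path

# Illustrative stand-in: the real TNH_DEFAULT_PATTERN_DIR is derived from
# the project root, not hard-coded like this.
TNH_DEFAULT_PATTERN_DIR = Path.home() / ".config" / "tnh-scholar" / "patterns"

def resolve_pattern_dir(env=None) -> Path:
    """Prefer TNH_PATTERN_DIR when set; otherwise fall back to the default."""
    env = os.environ if env is None else env
    if value := env.get("TNH_PATTERN_DIR"):
        return Path(value)
    return TNH_DEFAULT_PATTERN_DIR

print(resolve_pattern_dir({"TNH_PATTERN_DIR": "/tmp/patterns"}))  # /tmp/patterns
```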
CLI Tools
  • audio-transcribe - Audio file transcription utility
  • tnh-fab - Text processing and analysis toolkit

For more information, see:
  • Documentation: https://aaronksolomon.github.io/tnh-scholar/
  • Source: https://github.com/aaronksolomon/tnh-scholar
  • Issues: https://github.com/aaronksolomon/tnh-scholar/issues

Dependencies
  • Core: click, pydantic, openai, yt-dlp
  • Optional: streamlit (GUI), spacy (NLP), google-cloud-vision (OCR)

TNH_CLI_TOOLS_DIR = TNH_ROOT_SRC_DIR / 'cli_tools' module-attribute

TNH_CONFIG_DIR = Path.home() / '.config' / 'tnh-scholar' module-attribute

TNH_DEFAULT_PATTERN_DIR = TNH_PROJECT_ROOT_DIR / 'patterns' module-attribute

TNH_LOG_DIR = TNH_CONFIG_DIR / 'logs' module-attribute

TNH_METADATA_PROCESS_FIELD = 'tnh_processing' module-attribute

TNH_PROJECT_ROOT_DIR = TNH_ROOT_SRC_DIR.resolve().parent.parent module-attribute

TNH_ROOT_SRC_DIR = Path(__file__).resolve().parent module-attribute

__version__ = '0.1.3' module-attribute

ai_text_processing

Public surface for tnh_scholar.ai_text_processing.

Historically this module eagerly imported multiple submodules with heavy dependencies (e.g., audio codecs, ML toolkits) which made importing lightweight components such as Prompt surprisingly expensive and brittle in test environments. We now lazily import the concrete implementations on demand so that callers can depend on just the pieces they need.
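Lazy exports like this are commonly implemented with a module-level `__getattr__` (PEP 562). A minimal, self-contained sketch of the pattern, using hypothetical in-memory module names rather than the package's actual lazy-loading code:

```python
import importlib
import sys
import types

# Hypothetical package built in-memory to demonstrate the pattern; the
# real ai_text_processing module applies it to its own submodules.
pkg = types.ModuleType("demo_pkg")
LAZY_EXPORTS = {"Prompt": "demo_pkg._prompts"}  # public name -> submodule

def _lazy_getattr(name):
    """PEP 562 hook: import the owning submodule only on first access."""
    if name in LAZY_EXPORTS:
        return getattr(importlib.import_module(LAZY_EXPORTS[name]), name)
    raise AttributeError(f"module 'demo_pkg' has no attribute {name!r}")

pkg.__getattr__ = _lazy_getattr
sys.modules["demo_pkg"] = pkg

# The "heavy" submodule: nothing here is touched until demo_pkg.Prompt is.
sub = types.ModuleType("demo_pkg._prompts")
class Prompt:
    pass
sub.Prompt = Prompt
sys.modules["demo_pkg._prompts"] = sub

import demo_pkg
print(demo_pkg.Prompt is Prompt)  # True: resolved lazily via __getattr__
```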

__all__ = ['OpenAIProcessor', 'SectionParser', 'SectionProcessor', 'find_sections', 'process_text', 'process_text_by_paragraphs', 'process_text_by_sections', 'get_pattern', 'translate_text_by_lines', 'openai_process_text', 'GitBackedRepository', 'LocalPromptManager', 'Prompt', 'PromptCatalog', 'AIResponse', 'LogicalSection', 'SectionEntry', 'TextObject', 'TextObjectInfo'] module-attribute

AIResponse

Bases: BaseModel

Class for dividing large texts into AI-processable segments while maintaining broader document context.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
class AIResponse(BaseModel):
    """Class for dividing large texts into AI-processable segments while
    maintaining broader document context."""
    document_summary: str = Field(
        ...,
        description="Concise, comprehensive overview of the text's content and purpose"
    )
    document_metadata: str = Field(
        ...,
        description="Available Dublin Core standard metadata in human-readable YAML format" # noqa: E501
    )
    key_concepts: str = Field(
        ...,
        description="Important terms, ideas, or references that appear throughout the text"  # noqa: E501
    )
    narrative_context: str = Field(
        ...,
        description="Concise overview of how the text develops or progresses as a whole"
    )
    language: str = Field(..., description="ISO 639-1 language code")
    sections: List[LogicalSection]
document_metadata = Field(..., description='Available Dublin Core standard metadata in human-readable YAML format') class-attribute instance-attribute
document_summary = Field(..., description="Concise, comprehensive overview of the text's content and purpose") class-attribute instance-attribute
key_concepts = Field(..., description='Important terms, ideas, or references that appear throughout the text') class-attribute instance-attribute
language = Field(..., description='ISO 639-1 language code') class-attribute instance-attribute
narrative_context = Field(..., description='Concise overview of how the text develops or progresses as a whole') class-attribute instance-attribute
sections instance-attribute

GitBackedRepository

Manages versioned storage of prompts using Git.

Provides basic Git operations while hiding complexity:
- Automatic versioning of changes
- Basic conflict resolution
- History tracking

Source code in src/tnh_scholar/ai_text_processing/prompts.py
class GitBackedRepository:
    """
    Manages versioned storage of prompts using Git.

    Provides basic Git operations while hiding complexity:
    - Automatic versioning of changes
    - Basic conflict resolution
    - History tracking
    """

    def __init__(self, repo_path: Path):
        """
        Initialize or connect to Git repository.

        Args:
            repo_path: Path to repository directory

        Raises:
            GitCommandError: If Git operations fail
        """
        self.repo_path = repo_path

        try:
            # Try to connect to existing repository
            self.repo = Repo(repo_path)
            logger.debug(f"Connected to existing Git repository at {repo_path}")

        except InvalidGitRepositoryError:
            # Initialize new repository if none exists
            logger.info(f"Initializing new Git repository at {repo_path}")
            self.repo = Repo.init(repo_path)

            # Create initial commit if repo is empty
            if not self.repo.head.is_valid():
                # Create and commit .gitignore
                gitignore = repo_path / ".gitignore"
                gitignore.write_text("*.lock\n.DS_Store\n")
                self.repo.index.add([".gitignore"])
                self.repo.index.commit("Initial repository setup")

    def update_file(self, file_path: Path) -> str:
        """
        Stage and commit changes to a file in the Git repository.

        Args:
            file_path: Absolute or relative path to the file.

        Returns:
            str: Commit hash if changes were made.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file is outside the repository.
            GitCommandError: If Git operations fail.
        """
        file_path = file_path.resolve()

        # Ensure the file is within the repository
        try:
            rel_path = file_path.relative_to(self.repo_path)
        except ValueError as e:
            raise ValueError(
                f"File {file_path} is not under the repository root {self.repo_path}"
            ) from e

        if not file_path.exists():
            raise FileNotFoundError(f"File does not exist: {file_path}")

        try:
            return self._commit_file_update(rel_path, file_path)
        except GitCommandError as e:
            logger.error(f"Git operation failed: {e}")
            raise

    def _commit_file_update(self, rel_path, file_path):
        if self._is_file_clean(rel_path):
            # Return the current commit hash if no changes
            return self.repo.head.commit.hexsha

        logger.info(f"Detected changes in {rel_path}, updating version control.")
        self.repo.index.add([str(rel_path)])
        commit = self.repo.index.commit(
            f"{MANAGER_UPDATE_MESSAGE} {rel_path.stem}",
            author=Actor("PromptManager", ""),
        )
        logger.info(f"Committed changes to {file_path}: {commit.hexsha}")
        return commit.hexsha

    def _get_file_revisions(self, file_path: Path) -> List[Commit]:
        """
        Get ordered list of commits that modified a file, most recent first.

        Args:
            file_path: Path to file relative to repository root

        Returns:
            List of Commit objects affecting this file

        Raises:
            GitCommandError: If Git operations fail
        """
        rel_path = file_path.relative_to(self.repo_path)
        try:
            return list(self.repo.iter_commits(paths=str(rel_path)))
        except GitCommandError as e:
            logger.error(f"Failed to get commits for {rel_path}: {e}")
            return []

    def _get_commit_diff(
        self, commit: Commit, file_path: Path, prev_commit: Optional[Commit] = None
    ) -> Tuple[str, str]:
        """
        Get both stat and detailed diff for a commit.

        Args:
            commit: Commit to diff
            file_path: Path relative to repository root
            prev_commit: Previous commit for diff, defaults to commit's parent

        Returns:
            Tuple of (stat_diff, detailed_diff) where:
                stat_diff: Summary of changes (files changed, insertions/deletions)
                detailed_diff: Colored word-level diff with context

        Raises:
            GitCommandError: If Git operations fail
        """
        prev_hash = prev_commit.hexsha if prev_commit else f"{commit.hexsha}^"
        rel_path = file_path.relative_to(self.repo_path)

        try:
            # Get stats diff
            stat = self.repo.git.diff(prev_hash, commit.hexsha, rel_path, stat=True)

            # Get detailed diff
            diff = self.repo.git.diff(
                prev_hash,
                commit.hexsha,
                rel_path,
                unified=2,
                word_diff="plain",
                color="always",
                ignore_space_change=True,
            )

            return stat, diff
        except GitCommandError as e:
            logger.error(f"Failed to get diff for {commit.hexsha}: {e}")
            return "", ""

    def display_history(self, file_path: Path, max_versions: int = 0) -> None:
        """
        Display history of changes for a file with diffs between versions.

        Shows most recent changes first, limited to max_versions entries.
        For each change shows:
        - Commit info and date
        - Stats summary of changes
        - Detailed color diff with 2 lines of context

        Args:
            file_path: Path to file in repository
            max_versions: Maximum number of versions to show, 
            if zero, shows all revisions.

        Example:
            >>> repo.display_history(Path("prompts/format_dharma_talk.yaml"))
            Commit abc123def (2024-12-28 14:30:22):
            1 file changed, 5 insertions(+), 2 deletions(-)

            diff --git a/prompts/format_dharma_talk.yaml ...
            ...
        """

        try:
            # Get commit history
            commits = self._get_file_revisions(file_path)
            if not commits:
                print(f"No history found for {file_path}")
                return

            if max_versions == 0:
                max_versions = len(commits)  # look at all commits.

            # Display limited history with diffs
            for i, commit in enumerate(commits[:max_versions]):
                # Print commit header
                date_str = commit.committed_datetime.strftime("%Y-%m-%d %H:%M:%S")
                print(f"\nCommit {commit.hexsha[:8]} ({date_str}):")
                print(f"Message: {commit.message.strip()}")

                # Get and display diffs
                prev_commit = commits[i + 1] if i + 1 < len(commits) else None
                stat_diff, detailed_diff = self._get_commit_diff(
                    commit, file_path, prev_commit
                )

                if stat_diff:
                    print("\nChanges:")
                    print(stat_diff)
                if detailed_diff:
                    print("\nDetailed diff:")
                    print(detailed_diff)

                print("\033[0m", end="")
                print("-" * 80)  # Visual separator between commits

        except Exception as e:
            logger.error(f"Failed to display history for {file_path}: {e}")
            print(f"Error displaying history: {e}")
            raise

    def _is_file_clean(self, rel_path: Path) -> bool:
        """
        Check if file has uncommitted changes.

        Args:
            rel_path: Path relative to repository root

        Returns:
            bool: True if file has no changes
        """
        return str(rel_path) not in (
            [item.a_path for item in self.repo.index.diff(None)]
            + self.repo.untracked_files
        )
repo = Repo(repo_path) instance-attribute
repo_path = repo_path instance-attribute
__init__(repo_path)

Initialize or connect to Git repository.

Parameters:

Name       Type  Description                   Default
repo_path  Path  Path to repository directory  required

Raises:

Type             Description
GitCommandError  If Git operations fail

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def __init__(self, repo_path: Path):
    """
    Initialize or connect to Git repository.

    Args:
        repo_path: Path to repository directory

    Raises:
        GitCommandError: If Git operations fail
    """
    self.repo_path = repo_path

    try:
        # Try to connect to existing repository
        self.repo = Repo(repo_path)
        logger.debug(f"Connected to existing Git repository at {repo_path}")

    except InvalidGitRepositoryError:
        # Initialize new repository if none exists
        logger.info(f"Initializing new Git repository at {repo_path}")
        self.repo = Repo.init(repo_path)

        # Create initial commit if repo is empty
        if not self.repo.head.is_valid():
            # Create and commit .gitignore
            gitignore = repo_path / ".gitignore"
            gitignore.write_text("*.lock\n.DS_Store\n")
            self.repo.index.add([".gitignore"])
            self.repo.index.commit("Initial repository setup")
display_history(file_path, max_versions=0)

Display history of changes for a file with diffs between versions.

Shows most recent changes first, limited to max_versions entries. For each change shows:
- Commit info and date
- Stats summary of changes
- Detailed color diff with 2 lines of context

Parameters:

Name          Type  Description                                                        Default
file_path     Path  Path to file in repository                                         required
max_versions  int   Maximum number of versions to show; if zero, shows all revisions.  0
Example:
    >>> repo.display_history(Path("prompts/format_dharma_talk.yaml"))
    Commit abc123def (2024-12-28 14:30:22):
    1 file changed, 5 insertions(+), 2 deletions(-)

    diff --git a/prompts/format_dharma_talk.yaml ...
    ...

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def display_history(self, file_path: Path, max_versions: int = 0) -> None:
    """
    Display history of changes for a file with diffs between versions.

    Shows most recent changes first, limited to max_versions entries.
    For each change shows:
    - Commit info and date
    - Stats summary of changes
    - Detailed color diff with 2 lines of context

    Args:
        file_path: Path to file in repository
        max_versions: Maximum number of versions to show, 
        if zero, shows all revisions.

    Example:
        >>> repo.display_history(Path("prompts/format_dharma_talk.yaml"))
        Commit abc123def (2024-12-28 14:30:22):
        1 file changed, 5 insertions(+), 2 deletions(-)

        diff --git a/prompts/format_dharma_talk.yaml ...
        ...
    """

    try:
        # Get commit history
        commits = self._get_file_revisions(file_path)
        if not commits:
            print(f"No history found for {file_path}")
            return

        if max_versions == 0:
            max_versions = len(commits)  # look at all commits.

        # Display limited history with diffs
        for i, commit in enumerate(commits[:max_versions]):
            # Print commit header
            date_str = commit.committed_datetime.strftime("%Y-%m-%d %H:%M:%S")
            print(f"\nCommit {commit.hexsha[:8]} ({date_str}):")
            print(f"Message: {commit.message.strip()}")

            # Get and display diffs
            prev_commit = commits[i + 1] if i + 1 < len(commits) else None
            stat_diff, detailed_diff = self._get_commit_diff(
                commit, file_path, prev_commit
            )

            if stat_diff:
                print("\nChanges:")
                print(stat_diff)
            if detailed_diff:
                print("\nDetailed diff:")
                print(detailed_diff)

            print("\033[0m", end="")
            print("-" * 80)  # Visual separator between commits

    except Exception as e:
        logger.error(f"Failed to display history for {file_path}: {e}")
        print(f"Error displaying history: {e}")
        raise
update_file(file_path)

Stage and commit changes to a file in the Git repository.

Parameters:

Name       Type  Description                            Default
file_path  Path  Absolute or relative path to the file  required

Returns:

Name  Type  Description
str   str   Commit hash if changes were made

Raises:

Type               Description
FileNotFoundError  If the file does not exist
ValueError         If the file is outside the repository
GitCommandError    If Git operations fail

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def update_file(self, file_path: Path) -> str:
    """
    Stage and commit changes to a file in the Git repository.

    Args:
        file_path: Absolute or relative path to the file.

    Returns:
        str: Commit hash if changes were made.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the file is outside the repository.
        GitCommandError: If Git operations fail.
    """
    file_path = file_path.resolve()

    # Ensure the file is within the repository
    try:
        rel_path = file_path.relative_to(self.repo_path)
    except ValueError as e:
        raise ValueError(
            f"File {file_path} is not under the repository root {self.repo_path}"
        ) from e

    if not file_path.exists():
        raise FileNotFoundError(f"File does not exist: {file_path}")

    try:
        return self._commit_file_update(rel_path, file_path)
    except GitCommandError as e:
        logger.error(f"Git operation failed: {e}")
        raise
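The update_file flow above (stage the file, commit only when it actually changed, return the commit hash) can be sketched with the plain git CLI driven from the standard library. This uses subprocess rather than the GitPython calls the class actually makes, and the file names are illustrative:

```python
import subprocess
import tempfile
from pathlib import Path

def run_git(repo: Path, *args: str) -> str:
    out = subprocess.run(["git", "-C", str(repo), *args],
                         check=True, capture_output=True, text=True)
    return out.stdout.strip()

def update_file(repo: Path, rel_path: str, message: str) -> str:
    """Stage rel_path, commit only if it changed, return the HEAD hash."""
    run_git(repo, "add", rel_path)
    # `git diff --cached --quiet` exits non-zero when staged changes exist.
    dirty = subprocess.run(
        ["git", "-C", str(repo), "diff", "--cached", "--quiet"]
    ).returncode != 0
    if dirty:
        run_git(repo, "commit", "-q", "-m", message)
    return run_git(repo, "rev-parse", "HEAD")

repo = Path(tempfile.mkdtemp())
run_git(repo, "init", "-q")
run_git(repo, "config", "user.email", "promptmanager@example.com")
run_git(repo, "config", "user.name", "PromptManager")
# Mirror the class's "Initial repository setup" commit so HEAD exists.
(repo / ".gitignore").write_text("*.lock\n.DS_Store\n")
run_git(repo, "add", ".gitignore")
run_git(repo, "commit", "-q", "-m", "Initial repository setup")

(repo / "prompt.md").write_text("instructions v1\n")
print(len(update_file(repo, "prompt.md", "Prompt update: prompt")))  # 40
```

Calling `update_file` again without touching the file takes the clean-index path and simply returns the current HEAD hash, just as `_commit_file_update` does.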

LocalPromptManager

A simple singleton implementation of PromptManager that ensures only one instance is created and reused throughout the application lifecycle.

This class wraps the PromptManager to provide efficient prompt loading by maintaining a single reusable instance.

Attributes:

Name             Type                          Description
_instance        Optional[LocalPromptManager]  The singleton instance
_prompt_manager  Optional[PromptManager]       The wrapped PromptManager instance

Source code in src/tnh_scholar/ai_text_processing/prompts.py
class LocalPromptManager:
    """
    A simple singleton implementation of PromptManager that ensures only one instance
    is created and reused throughout the application lifecycle.

    This class wraps the PromptManager to provide efficient prompt loading by
    maintaining a single reusable instance.

    Attributes:
        _instance (Optional[SingletonPromptManager]): The singleton instance
        _prompt_manager (Optional[PromptManager]): The wrapped PromptManager instance
    """

    _instance: Optional["LocalPromptManager"] = None

    def __new__(cls) -> "LocalPromptManager":
        """
        Create or return the singleton instance.

        Returns:
            SingletonPromptManager: The singleton instance
        """
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._prompt_manager = None
        return cls._instance

    @property
    def prompt_manager(self) -> "PromptCatalog":
        """
        Lazy initialization of the PromptManager instance.

        Returns:
            PromptManager: The wrapped PromptManager instance

        Raises:
            RuntimeError: If PATTERN_REPO is not properly configured
        """
        if self._prompt_manager is None:  # type: ignore
            try:
                load_dotenv()
                if prompt_path_name := os.getenv("TNH_PATTERN_DIR"):
                    prompt_dir = Path(prompt_path_name)
                    logger.debug(f"prompt dir: {prompt_path_name}")
                else:
                    prompt_dir = TNH_DEFAULT_PATTERN_DIR
                self._prompt_manager = PromptCatalog(prompt_dir)
            except ImportError as err:
                raise RuntimeError(
                    "Failed to initialize PromptManager. Ensure prompt_manager "
                    f"module and PATTERN_REPO are properly configured: {err}"
                ) from err
        return self._prompt_manager

    def get_prompt(self, name: str) -> Prompt:
        """Get a prompt by name."""
        return self.prompt_manager.load(Prompt._normalize_name(name))
prompt_manager property

Lazy initialization of the PromptManager instance.

Returns:

Name           Type           Description
PromptManager  PromptCatalog  The wrapped PromptManager instance

Raises:

Type          Description
RuntimeError  If PATTERN_REPO is not properly configured

__new__()

Create or return the singleton instance.

Returns:

Name                    Type                Description
SingletonPromptManager  LocalPromptManager  The singleton instance

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def __new__(cls) -> "LocalPromptManager":
    """
    Create or return the singleton instance.

    Returns:
        SingletonPromptManager: The singleton instance
    """
    if cls._instance is None:
        cls._instance = super().__new__(cls)
        cls._instance._prompt_manager = None
    return cls._instance
get_prompt(name)

Get a prompt by name.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def get_prompt(self, name: str) -> Prompt:
    """Get a prompt by name."""
    return self.prompt_manager.load(Prompt._normalize_name(name))
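The singleton behavior can be exercised with a self-contained replica of the `__new__` pattern shown above; `Resource` here is a hypothetical stand-in for the lazily constructed PromptCatalog, not part of the package.

```python
from typing import Optional

class Resource:
    """Hypothetical stand-in for the lazily constructed PromptCatalog."""

class SingletonManager:
    _instance: Optional["SingletonManager"] = None

    def __new__(cls) -> "SingletonManager":
        # Create the single instance once; every later call returns it.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._resource = None
        return cls._instance

    @property
    def resource(self) -> Resource:
        # Lazy initialization: built on first access, then reused.
        if self._resource is None:
            self._resource = Resource()
        return self._resource

a, b = SingletonManager(), SingletonManager()
print(a is b, a.resource is b.resource)  # True True
```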

LogicalSection

Bases: BaseModel

Represents a contextually meaningful segment of a larger text.

Sections should preserve natural breaks in content (explicit section markers, topic shifts, argument development, narrative progression) while staying within specified size limits in order to create chunks suitable for AI processing.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
class LogicalSection(BaseModel):
    """
    Represents a contextually meaningful segment of a larger text.

    Sections should preserve natural breaks in content 
    (explicit section markers, topic shifts, argument development, narrative progression) 
    while staying within specified size limits in order to create chunks suitable for AI processing.
    """  # noqa: E501
    start_line: int = Field(
        ..., 
        description="Starting line number that begins this logical segment"
    )
    title: str = Field(
        ...,
        description="Descriptive title of section's key content"
    )
start_line = Field(..., description='Starting line number that begins this logical segment') class-attribute instance-attribute
title = Field(..., description="Descriptive title of section's key content") class-attribute instance-attribute
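A LogicalSection carries only a start_line and a title, so a consumer recovers each section's text by slicing from one section's start to the next's. One plausible sketch, with plain (start_line, title) tuples standing in for the Pydantic model:

```python
# Plain (start_line, title) tuples stand in for LogicalSection instances.
text = "Intro line\nMore intro\nPractice begins\nBreathe in\nClosing words"
lines = text.splitlines()
sections = [(1, "Introduction"), (3, "Practice"), (5, "Closing")]

chunks = {}
# Pair each section with the next one's start to find where it ends;
# start_line values are 1-based, hence the -1 offsets.
for (start, title), (next_start, _) in zip(
    sections, sections[1:] + [(len(lines) + 1, None)]
):
    chunks[title] = lines[start - 1 : next_start - 1]

print(chunks["Practice"])  # ['Practice begins', 'Breathe in']
```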

OpenAIProcessor

Bases: TextProcessor

OpenAI-based text processor implementation.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class OpenAIProcessor(TextProcessor):
    """OpenAI-based text processor implementation."""
    def __init__(self, model: Optional[str] = None, max_tokens: int = 0):
        if not model:
            model = DEFAULT_OPENAI_MODEL
        self.model = model
        self.max_tokens = max_tokens

    def process_text(
        self,
        input_str: str,
        instructions: str,
        response_format: Optional[Type[BaseModel]] = None,
        max_tokens: int = 0,
        **kwargs,
    ) -> ProcessorResult:
        """Process text using OpenAI API with optional structured output."""

        if max_tokens == 0 and self.max_tokens > 0:
            max_tokens = self.max_tokens

        return openai_process_text(
            input_str,
            instructions,
            model=self.model,
            max_tokens=max_tokens,
            response_format=response_format,
            **kwargs,
        )
max_tokens = max_tokens instance-attribute
model = model instance-attribute
__init__(model=None, max_tokens=0)
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def __init__(self, model: Optional[str] = None, max_tokens: int = 0):
    if not model:
        model = DEFAULT_OPENAI_MODEL
    self.model = model
    self.max_tokens = max_tokens
process_text(input_str, instructions, response_format=None, max_tokens=0, **kwargs)

Process text using OpenAI API with optional structured output.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_text(
    self,
    input_str: str,
    instructions: str,
    response_format: Optional[Type[BaseModel]] = None,
    max_tokens: int = 0,
    **kwargs,
) -> ProcessorResult:
    """Process text using OpenAI API with optional structured output."""

    if max_tokens == 0 and self.max_tokens > 0:
        max_tokens = self.max_tokens

    return openai_process_text(
        input_str,
        instructions,
        model=self.model,
        max_tokens=max_tokens,
        response_format=response_format,
        **kwargs,
    )
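The max_tokens handling in process_text reduces to a simple precedence rule, sketched here in isolation: an explicit per-call value wins, the instance default fills in when the call passes 0, and 0 overall means no limit was requested.

```python
def effective_max_tokens(call_value: int, instance_default: int) -> int:
    """Mirror OpenAIProcessor.process_text's token-limit fallback."""
    if call_value == 0 and instance_default > 0:
        return instance_default
    return call_value

print(effective_max_tokens(0, 1024),    # 1024: instance default applies
      effective_max_tokens(256, 1024),  # 256: explicit call value wins
      effective_max_tokens(0, 0))       # 0: no limit requested anywhere
```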

Prompt

Base Prompt class for version-controlled template prompts.

Prompts contain:
- Instructions: The main prompt instructions as a Jinja2 template.
  Note: Instructions are intended to be saved in markdown format in a .md file.
- Template fields: Default values for template variables
- Metadata: Name and identifier information

Version control is handled externally through Git, not in the prompt itself. Prompt identity is determined by the combination of identifiers.

Attributes:

Name                     Type            Description
name                     str             The name of the prompt
instructions             str             The Jinja2 template string for this prompt
default_template_fields  Dict[str, str]  Default values for template variables
_allow_empty_vars        bool            Whether to allow undefined template variables
_env                     Environment     Configured Jinja2 environment instance

Source code in src/tnh_scholar/ai_text_processing/prompts.py
class Prompt:
    """
    Base Prompt class for version-controlled template prompts.

    Prompts contain:
    - Instructions: The main prompt instructions as a Jinja2 template.
       Note: Instructions are intended to be saved in markdown format in a .md file.
    - Template fields: Default values for template variables
    - Metadata: Name and identifier information

    Version control is handled externally through Git, not in the prompt itself.
    Prompt identity is determined by the combination of identifiers.

    Attributes:
        name (str): The name of the prompt
        instructions (str): The Jinja2 template string for this prompt
        default_template_fields (Dict[str, str]): Default values for template variables
        _allow_empty_vars (bool): Whether to allow undefined template variables
        _env (Environment): Configured Jinja2 environment instance
    """

    @staticmethod
    def _normalize_name(value: str) -> str:
        """Canonicalize prompt names for case-insensitive handling.

        Currently: strip() + lower(). If future rules are needed (e.g.,
        removing punctuation, limiting length), implement them here.
        """
        return value.strip().lower()

    def __init__(
        self,
        name: str,
        instructions: MarkdownStr,
        path: Optional[Path] = None,
        default_template_fields: Optional[Dict[str, str]] = None,
        allow_empty_vars: bool = False,        
    ) -> None:
        """
        Initialize a new Prompt instance.

        Args:
            name: Unique name identifying the prompt
            instructions: Jinja2 template string containing the prompt
            default_template_fields: Optional default values for template variables
            allow_empty_vars: Whether to allow undefined template variables

        Raises:
            ValueError: If name or instructions are empty
            TemplateError: If template syntax is invalid
        """
        if not name or not instructions:
            raise ValueError("Name and instructions must not be empty")

        # Normalize prompt name to lowercase for case-insensitive handling
        name = Prompt._normalize_name(name)

        self.name = name
        self.instructions = instructions
        self.path = path
        self.default_template_fields = default_template_fields or {}
        self._allow_empty_vars = allow_empty_vars
        self._env = self._create_environment()

        # Validate template syntax on initialization
        self._validate_template()

    @staticmethod
    def _create_environment() -> Environment:
        """
        Create and configure a Jinja2 environment with optimal settings.

        Returns:
            Environment: Configured Jinja2 environment 
            with security and formatting options
        """
        return Environment(
            undefined=StrictUndefined,  # Raise errors for undefined variables
            trim_blocks=True,  # Remove first newline after a block
            lstrip_blocks=True,  # Strip tabs and spaces from the start of lines
            autoescape=True,  # Enable autoescaping for security
        )

    def _validate_template(self) -> None:
        """
        Validate the template syntax without rendering.

        Raises:
            TemplateError: If template syntax is invalid
        """
        try:
            self._env.parse(self.instructions)
        except TemplateError as e:
            raise TemplateError(
                f"Invalid template syntax in prompt '{self.name}': {str(e)}"
            ) from e

    def apply_template(self, field_values: Optional[Dict[str, str]] = None) -> str:
        """
        Apply template values to prompt instructions using Jinja2.

        Values precedence (highest to lowest):
        1. field_values (explicitly passed)
        2. frontmatter values (from prompt file)
        3. default_template_fields (prompt defaults)

        Args:
            field_values: Values to substitute into the template.
                        If None, uses frontmatter/defaults.

        Returns:
            str: Rendered instructions with template values applied.

        Raises:
            TemplateError: If template rendering fails
            ValueError: If required template variables are missing
        """
        # Get frontmatter values
        frontmatter = self.extract_frontmatter() or {}

        # Combine values with correct precedence using | operator
        template_values = self.default_template_fields | \
            frontmatter | (field_values or {})

        instructions = self.get_content_without_frontmatter()
        logger.debug(f"instructions without frontmatter:\n{instructions}")

        try:
            return self._render_template_with_values(instructions, template_values)
        except TemplateError as e:
            raise TemplateError(
                f"Template rendering failed for prompt '{self.name}': {str(e)}"
                ) from e

    def _render_template_with_values(
        self, 
        instructions: str, 
        template_values: dict
        ) -> str:
        """
        Validate and render template with provided values.

        Args:
            instructions: Template content without frontmatter
            template_values: Values to substitute into template

        Returns:
            Rendered template string

        Raises:
            ValueError: If required template variables are missing
        """
        # Parse for validation
        parsed_content = self._env.parse(instructions)
        required_vars = find_undeclared_variables(parsed_content)

        # Validate variables
        missing_vars = required_vars - set(template_values.keys())
        if missing_vars and not self._allow_empty_vars:
            raise ValueError(
                f"Missing required template variables in prompt '{self.name}': "
                f"{', '.join(sorted(missing_vars))}"
            )

        # Create and render template
        template = self._env.from_string(instructions)
        return template.render(**template_values)

    def extract_frontmatter(self) -> Optional[Dict[str, Any]]:
        """
        Extract and validate YAML frontmatter from markdown instructions.

        Returns:
            Optional[Dict]: Frontmatter data if found and valid, None otherwise

        Note:
            Frontmatter must be at the very start of the file and properly formatted.
        """

        pattern = r"\A---\s*\n(.*?)\n---\s*(?:\n|$)"
        if match := re.match(pattern, self.instructions, re.DOTALL):
            try:
                frontmatter = yaml.safe_load(match[1])
                if frontmatter is None:
                    return None
                if not isinstance(frontmatter, dict):
                    logger.warning(f"Frontmatter must be a YAML dictionary: "
                                   f"{frontmatter}")
                    return None
                return frontmatter
            except yaml.YAMLError as e:
                logger.warning(f"Invalid YAML in frontmatter: {e}")
                return None
        return None

    def get_content_without_frontmatter(self) -> str:
        """
        Get markdown content with frontmatter removed.

        Returns:
            str: Markdown content without frontmatter
        """
        pattern = r"\A---\s*\n.*?\n---\s*\n"
        return re.sub(pattern, "", self.instructions, flags=re.DOTALL)

    def update_frontmatter(self, new_data: Dict[str, Any]) -> None:
        """
        Update or add frontmatter to the markdown content.

        Args:
            new_data: Dictionary of frontmatter fields to update
        """

        current_frontmatter = self.extract_frontmatter() or {}
        updated_frontmatter = {**current_frontmatter, **new_data}

        # Create YAML string
        yaml_str = yaml.dump(
            updated_frontmatter, default_flow_style=False, allow_unicode=True
        )

        # Remove existing frontmatter if present
        content = self.get_content_without_frontmatter()

        # Combine new frontmatter with content
        self.instructions = f"---\n{yaml_str}---\n\n{content}"


    def source_bytes(self) -> bytes:
        """
        Best-effort raw bytes for prompt hashing.

        Prefers hashing exact on-disk bytes including front-matter.
        We therefore first read from `self.path` when it is set. Otherwise we
        fall back to hashing the concatenation of known templates. In V1, only
        the instructions (system template) are used for rendering.
        """
        # Preferred path: use on-disk bytes when available.
        if self.path is not None:
            return self.path.read_bytes()

        # Fallback: concatenate known templates deterministically
        sys_part = self.instructions or ""
        return sys_part.encode("utf-8")

    def content_hash(self) -> str:
        """
        Generate a SHA-256 hash of the prompt content.

        Useful for quick content comparison and change detection.

        Returns:
            str: Hexadecimal string of the SHA-256 hash
        """
        content = (
            f"{self.name}{self.instructions}"
            f"{sorted(self.default_template_fields.items())}"
            )
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert prompt to dictionary for serialization.

        Returns:
            Dict containing all prompt data in serializable format
        """
        return {
            "name": self.name,
            "instructions": self.instructions,
            "default_template_fields": self.default_template_fields,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Prompt":
        """
        Create prompt instance from dictionary data.

        Args:
            data: Dictionary containing prompt data

        Returns:
            Prompt: New prompt instance

        Raises:
            ValueError: If required fields are missing
        """
        required_fields = {"name", "instructions"}
        if missing_fields := required_fields - set(data.keys()):
            raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")

        return cls(
            name=Prompt._normalize_name(str(data["name"])),
            instructions=data["instructions"],
            path=None,
            default_template_fields=data.get("default_template_fields", {}),
        )

    def __eq__(self, other: object) -> bool:
        """Compare prompts based on their content."""
        if not isinstance(other, Prompt):
            return NotImplemented
        return self.content_hash() == other.content_hash()

    def __hash__(self) -> int:
        """Hash based on content hash for container operations."""
        return hash(self.content_hash())
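The validate-then-render flow of `_render_template_with_values` can be sketched with the stdlib's `string.Template` standing in for Jinja2. This is a simplification for illustration only: the real class uses `{{ name }}` syntax with `StrictUndefined`, and the variable names below are hypothetical.

```python
import string

# Stdlib stand-in for the Jinja2 flow: collect the variables a template
# requires, fail loudly if any are missing, then render.
instructions = "Dear $name, today's topic is $topic."
template = string.Template(instructions)

values = {"name": "friend", "topic": "mindful breathing"}

# string.Template exposes its compiled regex as `pattern`; the "named"
# group captures each $variable occurrence.
required = {
    m.group("named")
    for m in template.pattern.finditer(instructions)
    if m.group("named")
}
missing = required - values.keys()
if missing:
    raise ValueError(f"Missing required template variables: {sorted(missing)}")

print(template.substitute(values))
# Dear friend, today's topic is mindful breathing.
```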
default_template_fields = default_template_fields or {} instance-attribute
instructions = instructions instance-attribute
name = name instance-attribute
path = path instance-attribute
__eq__(other)

Compare prompts based on their content.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def __eq__(self, other: object) -> bool:
    """Compare prompts based on their content."""
    if not isinstance(other, Prompt):
        return NotImplemented
    return self.content_hash() == other.content_hash()
__hash__()

Hash based on content hash for container operations.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def __hash__(self) -> int:
    """Hash based on content hash for container operations."""
    return hash(self.content_hash())
__init__(name, instructions, path=None, default_template_fields=None, allow_empty_vars=False)

Initialize a new Prompt instance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | str | Unique name identifying the prompt | required |
| instructions | MarkdownStr | Jinja2 template string containing the prompt | required |
| default_template_fields | Optional[Dict[str, str]] | Optional default values for template variables | None |
| allow_empty_vars | bool | Whether to allow undefined template variables | False |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If name or instructions are empty |
| TemplateError | If template syntax is invalid |

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def __init__(
    self,
    name: str,
    instructions: MarkdownStr,
    path: Optional[Path] = None,
    default_template_fields: Optional[Dict[str, str]] = None,
    allow_empty_vars: bool = False,        
) -> None:
    """
    Initialize a new Prompt instance.

    Args:
        name: Unique name identifying the prompt
        instructions: Jinja2 template string containing the prompt
        default_template_fields: Optional default values for template variables
        allow_empty_vars: Whether to allow undefined template variables

    Raises:
        ValueError: If name or instructions are empty
        TemplateError: If template syntax is invalid
    """
    if not name or not instructions:
        raise ValueError("Name and instructions must not be empty")

    # Normalize prompt name to lowercase for case-insensitive handling
    name = Prompt._normalize_name(name)

    self.name = name
    self.instructions = instructions
    self.path = path
    self.default_template_fields = default_template_fields or {}
    self._allow_empty_vars = allow_empty_vars
    self._env = self._create_environment()

    # Validate template syntax on initialization
    self._validate_template()
apply_template(field_values=None)

Apply template values to prompt instructions using Jinja2.

Values precedence (highest to lowest):

1. field_values (explicitly passed)
2. frontmatter values (from prompt file)
3. default_template_fields (prompt defaults)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| field_values | Optional[Dict[str, str]] | Values to substitute into the template. If None, uses frontmatter/defaults. | None |

Returns:

| Type | Description |
| --- | --- |
| str | Rendered instructions with template values applied. |

Raises:

| Type | Description |
| --- | --- |
| TemplateError | If template rendering fails |
| ValueError | If required template variables are missing |

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def apply_template(self, field_values: Optional[Dict[str, str]] = None) -> str:
    """
    Apply template values to prompt instructions using Jinja2.

    Values precedence (highest to lowest):
    1. field_values (explicitly passed)
    2. frontmatter values (from prompt file)
    3. default_template_fields (prompt defaults)

    Args:
        field_values: Values to substitute into the template.
                    If None, uses frontmatter/defaults.

    Returns:
        str: Rendered instructions with template values applied.

    Raises:
        TemplateError: If template rendering fails
        ValueError: If required template variables are missing
    """
    # Get frontmatter values
    frontmatter = self.extract_frontmatter() or {}

    # Combine values with correct precedence using | operator
    template_values = self.default_template_fields | \
        frontmatter | (field_values or {})

    instructions = self.get_content_without_frontmatter()
    logger.debug(f"instructions without frontmatter:\n{instructions}")

    try:
        return self._render_template_with_values(instructions, template_values)
    except TemplateError as e:
        raise TemplateError(
            f"Template rendering failed for prompt '{self.name}': {str(e)}"
            ) from e
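The three-level precedence is plain Python 3.9+ dict union, where later operands win. A minimal standalone sketch (the field names here are hypothetical):

```python
# Later operands of | override earlier ones, so explicit field_values win
# over frontmatter values, which win over the prompt's defaults.
default_template_fields = {"tone": "neutral", "language": "en"}
frontmatter = {"language": "vi"}
field_values = {"tone": "warm"}

template_values = default_template_fields | frontmatter | (field_values or {})
print(template_values)  # {'tone': 'warm', 'language': 'vi'}
```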
content_hash()

Generate a SHA-256 hash of the prompt content.

Useful for quick content comparison and change detection.

Returns:

| Type | Description |
| --- | --- |
| str | Hexadecimal string of the SHA-256 hash |

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def content_hash(self) -> str:
    """
    Generate a SHA-256 hash of the prompt content.

    Useful for quick content comparison and change detection.

    Returns:
        str: Hexadecimal string of the SHA-256 hash
    """
    content = (
        f"{self.name}{self.instructions}"
        f"{sorted(self.default_template_fields.items())}"
        )
    return hashlib.sha256(content.encode("utf-8")).hexdigest()
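The hash recipe is deterministic, so two prompts with identical name, instructions, and defaults always compare equal. A standalone sketch of the same recipe:

```python
import hashlib

def content_hash(name, instructions, default_template_fields):
    # Same recipe as Prompt.content_hash: name + instructions + sorted
    # default fields, hashed with SHA-256.
    content = f"{name}{instructions}{sorted(default_template_fields.items())}"
    return hashlib.sha256(content.encode("utf-8")).hexdigest()

a = content_hash("greet", "Hello {{ name }}!", {"name": "world"})
b = content_hash("greet", "Hello {{ name }}!", {"name": "world"})
print(a == b, len(a))  # True 64
```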
extract_frontmatter()

Extract and validate YAML frontmatter from markdown instructions.

Returns:

| Type | Description |
| --- | --- |
| Optional[Dict] | Frontmatter data if found and valid, None otherwise |

Note

Frontmatter must be at the very start of the file and properly formatted.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def extract_frontmatter(self) -> Optional[Dict[str, Any]]:
    """
    Extract and validate YAML frontmatter from markdown instructions.

    Returns:
        Optional[Dict]: Frontmatter data if found and valid, None otherwise

    Note:
        Frontmatter must be at the very start of the file and properly formatted.
    """

    pattern = r"\A---\s*\n(.*?)\n---\s*(?:\n|$)"
    if match := re.match(pattern, self.instructions, re.DOTALL):
        try:
            frontmatter = yaml.safe_load(match[1])
            if frontmatter is None:
                return None
            if not isinstance(frontmatter, dict):
                logger.warning(f"Frontmatter must be a YAML dictionary: "
                               f"{frontmatter}")
                return None
            return frontmatter
        except yaml.YAMLError as e:
            logger.warning(f"Invalid YAML in frontmatter: {e}")
            return None
    return None
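The anchored regex only matches frontmatter at the very start of the text; YAML parsing (`yaml.safe_load`) is then applied to the captured group, which is omitted in this stdlib-only sketch:

```python
import re

# Same anchored pattern as extract_frontmatter: capture everything between
# the opening and closing "---" fences at the start of the document.
pattern = r"\A---\s*\n(.*?)\n---\s*(?:\n|$)"

doc = "---\ntitle: greeting\nlang: en\n---\nHello {{ name }}!"
if match := re.match(pattern, doc, re.DOTALL):
    print(match[1])
# title: greeting
# lang: en
```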
from_dict(data) classmethod

Create prompt instance from dictionary data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data | Dict[str, Any] | Dictionary containing prompt data | required |

Returns:

| Type | Description |
| --- | --- |
| Prompt | New prompt instance |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If required fields are missing |

Source code in src/tnh_scholar/ai_text_processing/prompts.py
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Prompt":
    """
    Create prompt instance from dictionary data.

    Args:
        data: Dictionary containing prompt data

    Returns:
        Prompt: New prompt instance

    Raises:
        ValueError: If required fields are missing
    """
    required_fields = {"name", "instructions"}
    if missing_fields := required_fields - set(data.keys()):
        raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")

    return cls(
        name=Prompt._normalize_name(str(data["name"])),
        instructions=data["instructions"],
        path=None,
        default_template_fields=data.get("default_template_fields", {}),
    )
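A standalone sketch of from_dict's required-field check and name normalization (the data values are hypothetical; normalization mirrors `Prompt._normalize_name`, i.e. `strip()` + `lower()`):

```python
# Validate that the serialized dict carries the mandatory keys, then
# normalize the name the same way the Prompt constructor does.
required_fields = {"name", "instructions"}
data = {"name": "  Format_Talk  ", "instructions": "Hello {{ name }}!"}

if missing_fields := required_fields - set(data.keys()):
    raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")

name = str(data["name"]).strip().lower()
print(name)  # format_talk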
get_content_without_frontmatter()

Get markdown content with frontmatter removed.

Returns:

| Type | Description |
| --- | --- |
| str | Markdown content without frontmatter |

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def get_content_without_frontmatter(self) -> str:
    """
    Get markdown content with frontmatter removed.

    Returns:
        str: Markdown content without frontmatter
    """
    pattern = r"\A---\s*\n.*?\n---\s*\n"
    return re.sub(pattern, "", self.instructions, flags=re.DOTALL)
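The same anchored substitution can be exercised on its own; the frontmatter block is removed and everything after it is left untouched:

```python
import re

# Strip a leading frontmatter block; text without frontmatter is unchanged.
pattern = r"\A---\s*\n.*?\n---\s*\n"

doc = "---\ntitle: greeting\n---\nHello {{ name }}!"
print(re.sub(pattern, "", doc, flags=re.DOTALL))  # Hello {{ name }}!
```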
source_bytes()

Best-effort raw bytes for prompt hashing.

Prefers hashing exact on-disk bytes including front-matter. We therefore first read from `self.path` when it is set. Otherwise we fall back to hashing the concatenation of known templates. In V1, only the instructions (system template) are used for rendering.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def source_bytes(self) -> bytes:
    """
    Best-effort raw bytes for prompt hashing.

    Prefers hashing exact on-disk bytes including front-matter.
    We therefore first read from `self.path` when it is set. Otherwise we
    fall back to hashing the concatenation of known templates. In V1, only
    the instructions (system template) are used for rendering.
    """
    # Preferred path: use on-disk bytes when available.
    if self.path is not None:
        return self.path.read_bytes()

    # Fallback: concatenate known templates deterministically
    sys_part = self.instructions or ""
    return sys_part.encode("utf-8")
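The preferred branch hashes the file's exact bytes, frontmatter included, so any on-disk change is detected. A sketch with a temporary file (filename and content are hypothetical):

```python
import hashlib
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "greet.md"
    path.write_text("---\ntitle: greeting\n---\nHello!", encoding="utf-8")

    # Preferred branch of source_bytes: the raw on-disk bytes.
    raw = path.read_bytes()
    digest = hashlib.sha256(raw).hexdigest()
    print(len(digest))  # 64
```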
to_dict()

Convert prompt to dictionary for serialization.

Returns:

| Type | Description |
| --- | --- |
| Dict[str, Any] | Dict containing all prompt data in serializable format |

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def to_dict(self) -> Dict[str, Any]:
    """
    Convert prompt to dictionary for serialization.

    Returns:
        Dict containing all prompt data in serializable format
    """
    return {
        "name": self.name,
        "instructions": self.instructions,
        "default_template_fields": self.default_template_fields,
    }
update_frontmatter(new_data)

Update or add frontmatter to the markdown content.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| new_data | Dict[str, Any] | Dictionary of frontmatter fields to update | required |

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def update_frontmatter(self, new_data: Dict[str, Any]) -> None:
    """
    Update or add frontmatter to the markdown content.

    Args:
        new_data: Dictionary of frontmatter fields to update
    """

    current_frontmatter = self.extract_frontmatter() or {}
    updated_frontmatter = {**current_frontmatter, **new_data}

    # Create YAML string
    yaml_str = yaml.dump(
        updated_frontmatter, default_flow_style=False, allow_unicode=True
    )

    # Remove existing frontmatter if present
    content = self.get_content_without_frontmatter()

    # Combine new frontmatter with content
    self.instructions = f"---\n{yaml_str}---\n\n{content}"
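The merge semantics can be shown on their own: incoming fields override existing ones and untouched fields survive (the YAML serialization step is omitted; the field names are hypothetical):

```python
# Sketch of the frontmatter merge in update_frontmatter.
current_frontmatter = {"version": 1, "author": "plum village"}
new_data = {"version": 2, "tags": ["dharma talk"]}

updated_frontmatter = {**current_frontmatter, **new_data}
print(updated_frontmatter)
# {'version': 2, 'author': 'plum village', 'tags': ['dharma talk']}
```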

PromptCatalog

Main interface for prompt management system.

Provides high-level operations:

- Prompt creation and loading
- Automatic versioning
- Safe concurrent access
- Basic history tracking
- Case-insensitive prompt names (stored as lowercase)

Source code in src/tnh_scholar/ai_text_processing/prompts.py
class PromptCatalog:
    """
    Main interface for prompt management system.

    Provides high-level operations:
    - Prompt creation and loading
    - Automatic versioning
    - Safe concurrent access
    - Basic history tracking
    - Case-insensitive prompt names (stored as lowercase)
    """

    def __init__(self, base_path: Path):
        """
        Initialize prompt management system.

        Args:
            base_path: Base directory for prompt storage
        """
        self.base_path = Path(base_path).resolve()
        self.base_path.mkdir(parents=True, exist_ok=True)

        # Initialize subsystems
        self.repo = GitBackedRepository(self.base_path)
        self.access_manager = ConcurrentAccessManager(self.base_path / ".locks")

        logger.info(f"Initialized prompt management system at {base_path}")

    def _normalize_path(self, path: Union[str, Path]) -> Path:
        """
        Normalize a path to be absolute under the repository base path.

        Handles these cases to same result:
        - "my_file" -> <base_path>/my_file
        - "<base_path>/my_file" -> <base_path>/my_file

        Args:
            path: Input path as string or Path

        Returns:
            Path: Absolute path under base_path

        Raises:
            ValueError: If path would resolve outside repository
        """
        path = Path(path)  # ensure we have a path

        # Join with base_path as needed: always interpret relative
        # paths as relative to the repository base path. This avoids
        # incorrectly handling nested relative paths like "a/b"
        # which may not have the same parent as self.base_path.
        if not path.is_absolute():
            path = self.base_path / path

        # Safety check after resolution
        resolved = path.resolve()
        try:
            resolved.relative_to(self.base_path)
        except ValueError as e:
            raise ValueError(
                f"Path {path} resolves outside repository: {self.base_path}"
            ) from e

        return resolved

    def get_path(self, prompt_name: str) -> Optional[Path]:
        """
        Recursively search for a prompt file with the given name (case-insensitive)
        in base_path and all subdirectories.

        Args:
            prompt_name: prompt name (without extension) to search for

        Returns:
            Optional[Path]: Full path to the found prompt file, or None if not found
        """
        target = Prompt._normalize_name(prompt_name)
        with suppress(StopIteration):
            for path in self.base_path.rglob("*.md"):
                if path.is_file() and path.stem.lower() == target:
                    logger.debug(
                        f"Found prompt file for name {prompt_name} at: {path}"
                    )
                    return self._normalize_path(path)
        logger.debug(f"No prompt file found with name: {prompt_name}")
        return None

    def save(self, prompt: Prompt, subdir: Optional[Path] = None) -> Path:
        prompt_name = Prompt._normalize_name(prompt.name)
        instructions = prompt.instructions

        if subdir is None:
            path = self.base_path / f"{prompt_name}.md"
        else:
            path = self.base_path / subdir / f"{prompt_name}.md"

        path = self._normalize_path(path)

        # Check for existing prompt by case-insensitive match
        existing_path = self.get_path(prompt_name)

        try:
            # Lock on the destination path name (lowercase) to avoid races
            with self.access_manager.file_lock(path):
                # If an existing file is present but at a different case/path, rename it
                if existing_path is not None and existing_path != path:
                    path.parent.mkdir(parents=True, exist_ok=True)
                    logger.info(
                        f"Renaming existing prompt file from {existing_path} to {path} "
                        "to enforce lowercase naming."
                    )
                    existing_path.rename(path)

                write_str_to_file(path, instructions, overwrite=True)
                self.repo.update_file(path)
                logger.info(f"Prompt saved at {path}")
                return path.relative_to(self.base_path)

        except Exception as e:
            logger.error(f"Failed to save prompt {prompt_name}: {e}")
            raise

    def load(self, prompt_name: str) -> Prompt:
        """
        Load the .md prompt file by name, extract placeholders, and
        return a fully constructed Prompt object.

        Args:
            prompt_name: Name of the prompt (without .md extension).

        Returns:
            A new Prompt object whose 'instructions' is the file's text
            and whose 'template_fields' are inferred from placeholders in
            those instructions.
        """
        prompt_name = Prompt._normalize_name(prompt_name)
        # Locate the .md file; raise if missing
        path = self.get_path(prompt_name)
        if not path:
            raise FileNotFoundError(f"No prompt file named {prompt_name}.md found in prompt catalog:\n"
                                    f"{self.base_path}"
                                    )

        # Acquire lock before reading
        with self.access_manager.file_lock(path):
            instructions = read_str_from_file(path)

        instructions = MarkdownStr(instructions)

        # Create the prompt from the raw .md text (name is already lowercase)
        prompt = Prompt(name=prompt_name, instructions=instructions, path=path)

        # Check for local uncommitted changes, updating file:
        self.repo.update_file(path)

        return prompt

    def show_history(self, prompt_name: str) -> None:
        if path := self.get_path(prompt_name):
            self.repo.display_history(path)
        else:
            logger.error(f"Path to {prompt_name} not found.")
            return

    # def get_prompt_history_from_path(self, path: Path) -> List[Dict[str, Any]]:
    #     """
    #     Get version history for a prompt.

    #     Args:
    #         path: Path to prompt file

    #     Returns:
    #         List of version information
    #     """
    #     path = self._normalize_path(path)

    #     return self.repo.get_history(path)

    @classmethod
    def verify_repository(cls, base_path: Path) -> bool:
        """
        Verify repository integrity and uniqueness of prompt names.

        Performs the following checks:
        1. Validates Git repository structure.
        2. Ensures no duplicate prompt names exist.

        Args:
            base_path: Repository path to verify.

        Returns:
            bool: True if the repository is valid 
            and contains no duplicate prompt files.
        """
        try:
            # Check if it's a valid Git repository
            repo = Repo(base_path)

            # Verify basic repository structure
            basic_valid = (
                repo.head.is_valid()
                and not repo.bare
                and (base_path / ".git").is_dir()
                and (base_path / ".locks").is_dir()
            )

            if not basic_valid:
                return False

            prompt_files = list(base_path.rglob("*.md"))
            seen_names: Dict[str, Path] = {}

            for prompt_file in prompt_files:
                # Skip files in .git directory
                if ".git" in prompt_file.parts:
                    continue

                # Case-insensitive key
                key = Prompt._normalize_name(prompt_file.stem)

                if key in seen_names:
                    logger.error(
                        f"Duplicate prompt file detected (case-insensitive):\n"
                        f"  First occurrence: {seen_names[key]}\n"
                        f"  Second occurrence: {prompt_file}"
                    )
                    return False

                seen_names[key] = prompt_file

            return True

        except Exception as e:  # covers InvalidGitRepositoryError as well
            logger.error(f"Repository verification failed: {e}")
            return False
access_manager = ConcurrentAccessManager(self.base_path / '.locks') instance-attribute
base_path = Path(base_path).resolve() instance-attribute
repo = GitBackedRepository(self.base_path) instance-attribute
__init__(base_path)

Initialize prompt management system.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| base_path | Path | Base directory for prompt storage | required |
Source code in src/tnh_scholar/ai_text_processing/prompts.py
def __init__(self, base_path: Path):
    """
    Initialize prompt management system.

    Args:
        base_path: Base directory for prompt storage
    """
    self.base_path = Path(base_path).resolve()
    self.base_path.mkdir(parents=True, exist_ok=True)

    # Initialize subsystems
    self.repo = GitBackedRepository(self.base_path)
    self.access_manager = ConcurrentAccessManager(self.base_path / ".locks")

    logger.info(f"Initialized prompt management system at {base_path}")
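A sketch of the base-path bootstrapping in `__init__`; the Git and lock subsystems (GitBackedRepository, ConcurrentAccessManager) are project-specific and omitted here:

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # Resolve the base path and create it if missing, as __init__ does.
    base_path = (Path(tmp) / "prompts").resolve()
    base_path.mkdir(parents=True, exist_ok=True)

    # The lock directory handed to ConcurrentAccessManager.
    (base_path / ".locks").mkdir(exist_ok=True)

    print(base_path.is_dir(), (base_path / ".locks").is_dir())  # True True
```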
get_path(prompt_name)

Recursively search for a prompt file with the given name (case-insensitive) in base_path and all subdirectories.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| prompt_name | str | Prompt name (without extension) to search for | required |

Returns:

| Type | Description |
| --- | --- |
| Optional[Path] | Full path to the found prompt file, or None if not found |

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def get_path(self, prompt_name: str) -> Optional[Path]:
    """
    Recursively search for a prompt file with the given name (case-insensitive)
    in base_path and all subdirectories.

    Args:
        prompt_name: prompt name (without extension) to search for

    Returns:
        Optional[Path]: Full path to the found prompt file, or None if not found
    """
    target = Prompt._normalize_name(prompt_name)
    with suppress(StopIteration):
        for path in self.base_path.rglob("*.md"):
            if path.is_file() and path.stem.lower() == target:
                logger.debug(
                    f"Found prompt file for name {prompt_name} at: {path}"
                )
                return self._normalize_path(path)
    logger.debug(f"No prompt file found with name: {prompt_name}")
    return None
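The case-insensitive lookup above can be sketched as a standalone helper (a minimal sketch using only pathlib; `find_prompt` is a hypothetical name, not part of the package):

```python
from pathlib import Path
from typing import Optional

def find_prompt(base: Path, name: str) -> Optional[Path]:
    """Return the first .md file whose stem matches `name`, case-insensitively."""
    target = name.strip().lower()
    for path in base.rglob("*.md"):
        if path.is_file() and path.stem.lower() == target:
            return path
    return None
```

As in `get_path`, the match is on the file stem only, so `Format_Talk.md` in any subdirectory is found by the query `"format_talk"`.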
load(prompt_name)

Load the .md prompt file by name, extract placeholders, and return a fully constructed Prompt object.

Parameters:

  prompt_name (str, required): Name of the prompt (without .md extension).

Returns:

  Prompt: A new Prompt object whose 'instructions' is the file's text and whose 'template_fields' are inferred from placeholders in those instructions.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def load(self, prompt_name: str) -> Prompt:
    """
    Load the .md prompt file by name, extract placeholders, and
    return a fully constructed Prompt object.

    Args:
        prompt_name: Name of the prompt (without .md extension).

    Returns:
        A new Prompt object whose 'instructions' is the file's text
        and whose 'template_fields' are inferred from placeholders in
        those instructions.
    """
    prompt_name = Prompt._normalize_name(prompt_name)
    # Locate the .md file; raise if missing
    path = self.get_path(prompt_name)
    if not path:
        raise FileNotFoundError(f"No prompt file named {prompt_name}.md found in prompt catalog:\n"
                                f"{self.base_path}"
                                )

    # Acquire lock before reading
    with self.access_manager.file_lock(path):
        instructions = read_str_from_file(path)

    instructions = MarkdownStr(instructions)

    # Create the prompt from the raw .md text (name is already lowercase)
    prompt = Prompt(name=prompt_name, instructions=instructions, path=path)

    # Check for local uncommitted changes, updating file:
    self.repo.update_file(path)

    return prompt
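The placeholder inference mentioned in the docstring is not shown in this excerpt; assuming str.format-style `{name}` placeholders, a minimal sketch might look like this (hypothetical `extract_placeholders` helper):

```python
import re
from typing import List

def extract_placeholders(instructions: str) -> List[str]:
    # Collect the unique {name}-style field names found in the text
    return sorted(set(re.findall(r"\{(\w+)\}", instructions)))
```

The actual placeholder syntax used by Prompt may differ; this only illustrates the general "scan text, derive template_fields" step.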
save(prompt, subdir=None)
Source code in src/tnh_scholar/ai_text_processing/prompts.py
def save(self, prompt: Prompt, subdir: Optional[Path] = None) -> Path:
    prompt_name = Prompt._normalize_name(prompt.name)
    instructions = prompt.instructions

    if subdir is None:
        path = self.base_path / f"{prompt_name}.md"
    else:
        path = self.base_path / subdir / f"{prompt_name}.md"

    path = self._normalize_path(path)

    # Check for existing prompt by case-insensitive match
    existing_path = self.get_path(prompt_name)

    try:
        # Lock on the destination path name (lowercase) to avoid races
        with self.access_manager.file_lock(path):
            # If an existing file is present but at a different case/path, rename it
            if existing_path is not None and existing_path != path:
                path.parent.mkdir(parents=True, exist_ok=True)
                logger.info(
                    f"Renaming existing prompt file from {existing_path} to {path} "
                    "to enforce lowercase naming."
                )
                existing_path.rename(path)

            write_str_to_file(path, instructions, overwrite=True)
            self.repo.update_file(path)
            logger.info(f"Prompt saved at {path}")
            return path.relative_to(self.base_path)

    except Exception as e:
        logger.error(f"Failed to save prompt {prompt_name}: {e}")
        raise
show_history(prompt_name)
Source code in src/tnh_scholar/ai_text_processing/prompts.py
def show_history(self, prompt_name: str) -> None:
    if path := self.get_path(prompt_name):
        self.repo.display_history(path)
    else:
        logger.error(f"Path to {prompt_name} not found.")
        return
verify_repository(base_path) classmethod

Verify repository integrity and uniqueness of prompt names.

Performs the following checks:

1. Validates the Git repository structure.
2. Ensures no duplicate prompt names exist.

Parameters:

  base_path (Path, required): Repository path to verify.

Returns:

  bool: True if the repository is valid and contains no duplicate prompt files.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
@classmethod
def verify_repository(cls, base_path: Path) -> bool:
    """
    Verify repository integrity and uniqueness of prompt names.

    Performs the following checks:
    1. Validates Git repository structure.
    2. Ensures no duplicate prompt names exist.

    Args:
        base_path: Repository path to verify.

    Returns:
        bool: True if the repository is valid 
        and contains no duplicate prompt files.
    """
    try:
        # Check if it's a valid Git repository
        repo = Repo(base_path)

        # Verify basic repository structure
        basic_valid = (
            repo.head.is_valid()
            and not repo.bare
            and (base_path / ".git").is_dir()
            and (base_path / ".locks").is_dir()
        )

        if not basic_valid:
            return False

        prompt_files = list(base_path.rglob("*.md"))
        seen_names: Dict[str, Path] = {}

        for prompt_file in prompt_files:
            # Skip files in .git directory
            if ".git" in prompt_file.parts:
                continue

            # Case-insensitive key
            key = Prompt._normalize_name(prompt_file.stem)

            if key in seen_names:
                logger.error(
                    f"Duplicate prompt file detected (case-insensitive):\n"
                    f"  First occurrence: {seen_names[key]}\n"
                    f"  Second occurrence: {prompt_file}"
                )
                return False

            seen_names[key] = prompt_file

        return True

    except Exception as e:  # includes InvalidGitRepositoryError
        logger.error(f"Repository verification failed: {e}")
        return False
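The duplicate-detection pass above reduces to a case-insensitive collision check on file stems, sketched here as a standalone helper (hypothetical `find_duplicate_stems`, not part of the package):

```python
from pathlib import Path
from typing import Dict, List, Tuple

def find_duplicate_stems(paths: List[Path]) -> List[Tuple[Path, Path]]:
    """Return (first_seen, duplicate) pairs for case-insensitively colliding stems."""
    seen: Dict[str, Path] = {}
    collisions: List[Tuple[Path, Path]] = []
    for p in paths:
        key = p.stem.lower()
        if key in seen:
            collisions.append((seen[key], p))
        else:
            seen[key] = p
    return collisions
```

This is the same "first occurrence wins, later occurrence flagged" behavior the error message in `verify_repository` reports.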

SectionEntry

Bases: NamedTuple

Represents a section with its content during iteration.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
class SectionEntry(NamedTuple):
    """Represents a section with its content during iteration."""
    number: int         # Logical Section number (1 based index)
    title: str          # Section title 
    content: str        # Section content
    range: SectionRange # Section range
content instance-attribute
number instance-attribute
range instance-attribute
title instance-attribute

SectionParser

Generates structured section breakdowns of text content.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class SectionParser:
    """Generates structured section breakdowns of text content."""

    def __init__(
        self,
        section_scanner: TextProcessor,
        section_pattern: Prompt,
        review_count: int = DEFAULT_REVIEW_COUNT,
    ):
        """
        Initialize section generator.

        Args:
            section_scanner: TextProcessor implementation used to scan for sections
            section_pattern: Prompt object containing section generation instructions
            review_count: Number of review passes
        """
        self.section_scanner = section_scanner
        self.section_pattern = section_pattern
        self.review_count = review_count

    def find_sections(
        self,
        text: TextObject,
        section_count_target: Optional[int] = None,
        segment_size_target: Optional[int] = None,
        template_dict: Optional[Dict[str, str]] = None,
    ) -> TextObject:
        """
        Generate section breakdown of input text. The text must be split up by newlines.

        Args:
            text: Input TextObject to process
            section_count_target: the target for the number of sections to find
            segment_size_target: the target for the number of lines per section
                (if section_count_target is specified, 
                this value will be set to generate correct segments)
            template_dict: Optional additional template variables

        Returns:
            TextObject containing section breakdown
        """

        # Prepare numbered text, each line is numbered
        num_text = text.num_text

        if num_text.size < SECTION_SEGMENT_SIZE_WARNING_LIMIT:
            logger.warning(
                f"find_sections: Text has only {num_text.size} lines. "
                "This may lead to unexpected sectioning results."
            )

        # Get language if not specified
        source_language = get_language_from_code(text.language)

        # determine section count if not specified
        if not section_count_target:
            segment_size_target, section_count_target = self._get_section_count_info(
                text.content
            )
        elif not segment_size_target:
            segment_size_target = round(num_text.size / section_count_target)

        section_count_range = self._get_section_count_range(section_count_target)

        current_metadata = text.metadata

        # Prepare template variables
        template_values = {
            "metadata": current_metadata.to_yaml(),
            "source_language": source_language,
            "section_count": section_count_range,
            "line_count": segment_size_target,
            "review_count": self.review_count,
        }

        if template_dict:
            template_values |= template_dict

        # Get and apply processing instructions
        instructions = self.section_pattern.apply_template(template_values)
        logger.debug(f"Finding sections with pattern instructions:\n {instructions}")

        logger.info(
            f"Finding sections for {source_language} text "
            f"(target sections: {section_count_target})"
        )

        # Process text with structured output
        result = self.section_scanner.process_text(
            num_text.numbered_content, instructions, response_format=AIResponse
        )

        ai_response = cast(AIResponse, result)
        text_result = TextObject.from_response(ai_response, current_metadata, num_text)

        logger.info(f"Generated {text_result.section_count} sections.")

        return text_result

    def _get_section_count_info(self, text: str) -> Tuple[int, int]:
        num_text = NumberedText(text)
        segment_size = _calculate_segment_size(num_text, DEFAULT_SECTION_TOKEN_SIZE)
        section_count_target = round(num_text.size / segment_size)
        return segment_size, section_count_target

    def _get_section_count_range(
        self,
        section_count_target: int,
        section_range_var: int = DEFAULT_SECTION_RANGE_VAR,
    ) -> str:
        low = max(1, section_count_target - section_range_var)
        high = section_count_target + section_range_var
        return f"{low}-{high}"
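The range string built by `_get_section_count_range` can be sketched standalone (hypothetical `section_count_range` helper; `variance=2` is an assumed stand-in for DEFAULT_SECTION_RANGE_VAR, whose value is not shown in this excerpt):

```python
def section_count_range(target: int, variance: int = 2) -> str:
    # Clamp the low end at 1 so small targets never produce a zero or negative bound
    low = max(1, target - variance)
    return f"{low}-{target + variance}"
```

The clamp matters for short texts: a target of 1 yields "1-3" rather than "-1-3".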
review_count = review_count instance-attribute
section_pattern = section_pattern instance-attribute
section_scanner = section_scanner instance-attribute
__init__(section_scanner, section_pattern, review_count=DEFAULT_REVIEW_COUNT)

Initialize section generator.

Parameters:

  section_scanner (TextProcessor, required): TextProcessor implementation used to scan for sections.
  section_pattern (Prompt, required): Prompt object containing section generation instructions.
  review_count (int, default DEFAULT_REVIEW_COUNT): Number of review passes.
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def __init__(
    self,
    section_scanner: TextProcessor,
    section_pattern: Prompt,
    review_count: int = DEFAULT_REVIEW_COUNT,
):
    """
    Initialize section generator.

    Args:
        section_scanner: TextProcessor implementation used to scan for sections
        section_pattern: Prompt object containing section generation instructions
        review_count: Number of review passes
    """
    self.section_scanner = section_scanner
    self.section_pattern = section_pattern
    self.review_count = review_count
find_sections(text, section_count_target=None, segment_size_target=None, template_dict=None)

Generate section breakdown of input text. The text must be split up by newlines.

Parameters:

  text (TextObject, required): Input TextObject to process.
  section_count_target (Optional[int], default None): Target number of sections to find.
  segment_size_target (Optional[int], default None): Target number of lines per section. If section_count_target is specified, this value is derived from it.
  template_dict (Optional[Dict[str, str]], default None): Optional additional template variables.

Returns:

  TextObject: TextObject containing the section breakdown.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def find_sections(
    self,
    text: TextObject,
    section_count_target: Optional[int] = None,
    segment_size_target: Optional[int] = None,
    template_dict: Optional[Dict[str, str]] = None,
) -> TextObject:
    """
    Generate section breakdown of input text. The text must be split up by newlines.

    Args:
        text: Input TextObject to process
        section_count_target: the target for the number of sections to find
        segment_size_target: the target for the number of lines per section
            (if section_count_target is specified, 
            this value will be set to generate correct segments)
        template_dict: Optional additional template variables

    Returns:
        TextObject containing section breakdown
    """

    # Prepare numbered text, each line is numbered
    num_text = text.num_text

    if num_text.size < SECTION_SEGMENT_SIZE_WARNING_LIMIT:
        logger.warning(
            f"find_sections: Text has only {num_text.size} lines. "
            "This may lead to unexpected sectioning results."
        )

    # Get language if not specified
    source_language = get_language_from_code(text.language)

    # determine section count if not specified
    if not section_count_target:
        segment_size_target, section_count_target = self._get_section_count_info(
            text.content
        )
    elif not segment_size_target:
        segment_size_target = round(num_text.size / section_count_target)

    section_count_range = self._get_section_count_range(section_count_target)

    current_metadata = text.metadata

    # Prepare template variables
    template_values = {
        "metadata": current_metadata.to_yaml(),
        "source_language": source_language,
        "section_count": section_count_range,
        "line_count": segment_size_target,
        "review_count": self.review_count,
    }

    if template_dict:
        template_values |= template_dict

    # Get and apply processing instructions
    instructions = self.section_pattern.apply_template(template_values)
    logger.debug(f"Finding sections with pattern instructions:\n {instructions}")

    logger.info(
        f"Finding sections for {source_language} text "
        f"(target sections: {section_count_target})"
    )

    # Process text with structured output
    result = self.section_scanner.process_text(
        num_text.numbered_content, instructions, response_format=AIResponse
    )

    ai_response = cast(AIResponse, result)
    text_result = TextObject.from_response(ai_response, current_metadata, num_text)

    logger.info(f"Generated {text_result.section_count} sections.")

    return text_result

SectionProcessor

Handles section-based XML text processing with configurable output handling.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class SectionProcessor:
    """Handles section-based XML text processing with configurable output handling."""

    def __init__(
        self,
        processor: TextProcessor,
        pattern: Prompt,
        template_dict: Dict,
        wrap_in_document: bool = True,
    ):
        """
        Initialize the XML section processor.

        Args:
            processor: Implementation of TextProcessor to use
            pattern: Pattern object containing processing instructions
            template_dict: Dictionary for template substitution
            wrap_in_document: Whether to wrap output in <document> tags
        """
        self.processor = processor
        self.pattern = pattern
        self.template_dict = template_dict
        self.wrap_in_document = wrap_in_document

    def process_sections(
        self,
        text_object: TextObject,
    ) -> Generator[ProcessedSection, None, None]:
        """
        Process transcript sections and yield results one section at a time.

        Args:
            text_object: TextObject containing the text and its section definitions

        Yields:
            ProcessedSection: One processed section at a time, containing:
                - title: Section title (English or original language)
                - original_str: Raw text segment
                - processed_str: Processed text content
        """
        # numbered_transcript = NumberedText(transcript) 
        # transcript is now stored in the TextObject
        sections = text_object.sections

        logger.info(
            f"Processing {len(sections)} sections with pattern: {self.pattern.name}"
        )

        for section_entry in text_object:
            logger.info(f"Processing section {section_entry.number} "
                        f"'{section_entry.title}':")

            # Get text segment for section
            text_segment = section_entry.content

            # Prepare template variables
            template_values = {
                "metadata": text_object.metadata.to_yaml(),
                "section_title": section_entry.title,
                "source_language": get_language_from_code(text_object.language),
                "review_count": DEFAULT_REVIEW_COUNT,
            }

            if self.template_dict:
                template_values |= self.template_dict

            # Get and apply processing instructions
            instructions = self.pattern.apply_template(template_values)
            processed_str = self.processor.process_text(text_segment, instructions)

            yield ProcessedSection(
                title=section_entry.title,
                original_str=text_segment,
                processed_str=processed_str,
            )

    def process_paragraphs(
        self,
        text: TextObject,
    ) -> Generator[ProcessedSection, None, None]:
        """
        Process transcript by paragraphs (as sections), yielding ProcessedSection objects.
        Paragraphs are assumed to be given as newline separated.

        Args:
            text: TextObject to process

        Yields:
            ProcessedSection: One processed paragraph at a time, containing:
                - title: Paragraph number (e.g., 'Paragraph 1')
                - original_str: Raw paragraph text
                - processed_str: Processed paragraph text
                - metadata: Optional metadata dict
        """
        num_text = text.num_text

        logger.info(f"Processing lines as paragraphs with pattern: {self.pattern.name}")

        for i, line in num_text:
            # If line is empty or whitespace, continue
            if not line.strip():
                continue

            instructions = self.pattern.apply_template(self.template_dict)

            if i <= 1:
                logger.debug(f"Process instructions (first paragraph):\n{instructions}")

            processed_str = self.processor.process_text(line, instructions)
            yield ProcessedSection(
                title=f"Paragraph {i}",
                original_str=line,
                processed_str=processed_str,
                metadata={"paragraph_number": i}
            )
pattern = pattern instance-attribute
processor = processor instance-attribute
template_dict = template_dict instance-attribute
wrap_in_document = wrap_in_document instance-attribute
__init__(processor, pattern, template_dict, wrap_in_document=True)

Initialize the XML section processor.

Parameters:

  processor (TextProcessor, required): Implementation of TextProcessor to use.
  pattern (Prompt, required): Pattern object containing processing instructions.
  template_dict (Dict, required): Dictionary for template substitution.
  wrap_in_document (bool, default True): Whether to wrap output in <document> tags.
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def __init__(
    self,
    processor: TextProcessor,
    pattern: Prompt,
    template_dict: Dict,
    wrap_in_document: bool = True,
):
    """
    Initialize the XML section processor.

    Args:
        processor: Implementation of TextProcessor to use
        pattern: Pattern object containing processing instructions
        template_dict: Dictionary for template substitution
        wrap_in_document: Whether to wrap output in <document> tags
    """
    self.processor = processor
    self.pattern = pattern
    self.template_dict = template_dict
    self.wrap_in_document = wrap_in_document
process_paragraphs(text)

Process transcript by paragraphs (as sections), yielding ProcessedSection objects. Paragraphs are assumed to be given as newline separated.

Parameters:

  text (TextObject, required): TextObject to process.

Yields:

  ProcessedSection: One processed paragraph at a time, containing:
    - title: Paragraph number (e.g., 'Paragraph 1')
    - original_str: Raw paragraph text
    - processed_str: Processed paragraph text
    - metadata: Optional metadata dict

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_paragraphs(
    self,
    text: TextObject,
) -> Generator[ProcessedSection, None, None]:
    """
    Process transcript by paragraphs (as sections), yielding ProcessedSection objects.
    Paragraphs are assumed to be given as newline separated.

    Args:
        text: TextObject to process

    Yields:
        ProcessedSection: One processed paragraph at a time, containing:
            - title: Paragraph number (e.g., 'Paragraph 1')
            - original_str: Raw paragraph text
            - processed_str: Processed paragraph text
            - metadata: Optional metadata dict
    """
    num_text = text.num_text

    logger.info(f"Processing lines as paragraphs with pattern: {self.pattern.name}")

    for i, line in num_text:
        # If line is empty or whitespace, continue
        if not line.strip():
            continue

        instructions = self.pattern.apply_template(self.template_dict)

        if i <= 1:
            logger.debug(f"Process instructions (first paragraph):\n{instructions}")

        processed_str = self.processor.process_text(line, instructions)
        yield ProcessedSection(
            title=f"Paragraph {i}",
            original_str=line,
            processed_str=processed_str,
            metadata={"paragraph_number": i}
        )
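The non-blank-line iteration that drives `process_paragraphs` can be sketched standalone (hypothetical `iter_paragraphs` helper, mirroring the skip-blank loop above):

```python
from typing import Iterator, Tuple

def iter_paragraphs(text: str) -> Iterator[Tuple[int, str]]:
    """Yield (line_number, line) for non-blank lines, with 1-based numbering."""
    for i, line in enumerate(text.splitlines(), start=1):
        if line.strip():
            yield i, line
```

Note that line numbers are preserved across skipped blanks, which is what makes the `paragraph_number` metadata stable.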
process_sections(text_object)

Process transcript sections and yield results one section at a time.

Parameters:

  text_object (TextObject, required): TextObject containing the text and its section definitions.

Yields:

  ProcessedSection: One processed section at a time, containing:
    - title: Section title (English or original language)
    - original_str: Raw text segment
    - processed_str: Processed text content

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_sections(
    self,
    text_object: TextObject,
) -> Generator[ProcessedSection, None, None]:
    """
    Process transcript sections and yield results one section at a time.

    Args:
        text_object: TextObject containing the text and its section definitions

    Yields:
        ProcessedSection: One processed section at a time, containing:
            - title: Section title (English or original language)
            - original_str: Raw text segment
            - processed_str: Processed text content
    """
    # numbered_transcript = NumberedText(transcript) 
    # transcript is now stored in the TextObject
    sections = text_object.sections

    logger.info(
        f"Processing {len(sections)} sections with pattern: {self.pattern.name}"
    )

    for section_entry in text_object:
        logger.info(f"Processing section {section_entry.number} "
                    f"'{section_entry.title}':")

        # Get text segment for section
        text_segment = section_entry.content

        # Prepare template variables
        template_values = {
            "metadata": text_object.metadata.to_yaml(),
            "section_title": section_entry.title,
            "source_language": get_language_from_code(text_object.language),
            "review_count": DEFAULT_REVIEW_COUNT,
        }

        if self.template_dict:
            template_values |= self.template_dict

        # Get and apply processing instructions
        instructions = self.pattern.apply_template(template_values)
        processed_str = self.processor.process_text(text_segment, instructions)

        yield ProcessedSection(
            title=section_entry.title,
            original_str=text_segment,
            processed_str=processed_str,
        )

TextObject

Manages text content with section organization and metadata tracking.

TextObject serves as the core container for text processing, providing:

- Line-numbered text content management
- Language identification
- Section organization and access
- Metadata tracking including incorporated processing stages

The class defines section boundaries through line numbering: sections are specified by start lines without explicit end lines, and each section implicitly ends where the next one begins. SectionObject instances represent the sections.
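The implicit end-line rule can be sketched as a small helper (hypothetical `section_ranges`, not part of the package): each section's end is the next section's start, and the final section ends after the last line:

```python
from typing import List, Tuple

def section_ranges(start_lines: List[int], last_line: int) -> List[Tuple[int, int]]:
    """Derive (start, end) pairs; each section ends where the next begins."""
    ends = start_lines[1:] + [last_line + 1]
    return list(zip(start_lines, ends))
```

This mirrors the boundary derivation visible in `_build_section_objects` in the source below.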

Attributes:

  num_text (NumberedText): Line-numbered text content manager.
  language (str): ISO 639-1 language code for the text content.
  _sections (List[SectionObject]): Internal list of text sections with boundaries.
  _metadata (Metadata): Processing and content metadata container.

Example:

  >>> content = NumberedText("Line 1\nLine 2\nLine 3")
  >>> obj = TextObject(content, language="en")

Source code in src/tnh_scholar/ai_text_processing/text_object.py
class TextObject:
    """
    Manages text content with section organization and metadata tracking.

    TextObject serves as the core container for text processing, providing:
    - Line-numbered text content management
    - Language identification
    - Section organization and access
    - Metadata tracking including incorporated processing stages

    The class defines section boundaries through line numbering: sections
    are specified by start lines without explicit end lines, and each
    section implicitly ends where the next one begins. SectionObject
    instances represent the sections.

    Attributes:
        num_text: Line-numbered text content manager
        language: ISO 639-1 language code for the text content
        _sections: Internal list of text sections with boundaries
        _metadata: Processing and content metadata container

    Example:
        >>> content = NumberedText("Line 1\\nLine 2\\nLine 3")
        >>> obj = TextObject(content, language="en")
    """
    num_text: NumberedText 
    language: str 
    _sections: List[SectionObject]
    _metadata: Metadata

    def __init__(self, 
        num_text: NumberedText, 
        language: Optional[str] = None, 
        sections: Optional[List[SectionObject]] = None,
        metadata: Optional[Metadata] = None):
        """
        Initialize a TextObject with content and optional organizing components.

        Args:
            num_text: Text content with line numbering
            language: ISO 639-1 language code. If None, auto-detected from content
            sections: Initial sections defining text organization. If None, 
                      text is considered un-sectioned.
            metadata: Initial metadata. If None, creates empty metadata container

        Note:
            Until sections are established, section-based methods will raise a value
            error if called.
        """
        self.num_text = num_text
        self.language = language or get_language_code_from_text(num_text.content)
        self._sections = sections or []
        self._metadata = metadata or Metadata()

        if sections:
            self.validate_sections()


    def __iter__(self) -> Iterator[SectionEntry]:
        """Iterate through sections, yielding full section information."""
        if not self._sections:
            raise ValueError("No Sections available.")

        for i, section in enumerate(self._sections):
            content = self.num_text.get_segment(
                section.section_range.start, 
                section.section_range.end
            )
            yield SectionEntry(
                number=i+1,
                title=section.title,
                range=section.section_range,
                content=content
            )

    def __str__(self) -> str:
        return Frontmatter.embed(self.metadata, self.content)

    @staticmethod
    def _build_section_objects(
        logical_sections: List[LogicalSection], 
        last_line: int,
        metadata: Optional[Metadata] = None
    ) -> List[SectionObject]:
        """Convert LogicalSections to SectionObjects with proper ranges."""
        section_objects = []

        for i, section in enumerate(logical_sections):
            # For each section, end is either next section's start or last line + 1
            end_line = (logical_sections[i + 1].start_line 
                    if i < len(logical_sections) - 1 
                    else last_line + 1)

            section_objects.append(
                SectionObject.from_logical_section(section, end_line, metadata)
            )

        return section_objects

    @classmethod
    def from_str(
        cls,
        text: str,
        language: Optional[str] = None,
        sections: Optional[List[SectionObject]] = None,
        metadata: Optional[Metadata] = None
    ) -> 'TextObject':
        """
        Create a TextObject from a string, extracting any frontmatter.

        Args:
            text: Input text string, potentially containing frontmatter
            language: ISO language code
            sections: List of section objects
            metadata: Optional base metadata to merge with frontmatter

        Returns:
            TextObject instance with combined metadata
        """
        # Extract any frontmatter and merge with provided metadata
        frontmatter_metadata, content = Frontmatter.extract(text)

        # Create NumberedText from content without frontmatter
        numbered_text = NumberedText(content)

        obj = cls(
            num_text=numbered_text,
            language=language,
            sections=sections,
            metadata=frontmatter_metadata
        )
        if metadata:
            obj.merge_metadata(metadata)

        return obj


    @classmethod
    def from_response(
        cls, 
        response: AIResponse,
        existing_metadata: Metadata,
        num_text: 'NumberedText'
    ) -> 'TextObject':
        """Create TextObject from AI response format."""
        # Create metadata from response
        ai_metadata = response.document_metadata
        new_metadata = Metadata({
            "ai_summary": response.document_summary,
            "ai_concepts": response.key_concepts,
            "ai_context": response.narrative_context
        })

        # Convert LogicalSections to SectionObjects
        sections = cls._build_section_objects(
            response.sections, 
            num_text.size,
        )

        text = cls(
            num_text=num_text,
            language=response.language,
            sections=sections,
            metadata=existing_metadata
        )
        text.merge_metadata(new_metadata)
        text.merge_metadata(Metadata.from_yaml(ai_metadata))
        return text

    def merge_metadata(self, new_metadata: Metadata, override=False) -> None:
        """
        Merge new metadata with existing metadata.

        For now, performs simple dict-like union (|=) but can be extended 
        to handle more complex merging logic in the future (e.g., merging 
        nested structures, handling conflicts, merging arrays).

        Args:
            new_metadata: Metadata to merge with existing metadata
            override: If True, new_metadata values override existing values.
                If False, existing values are preserved.
        """
        # Currently using simple dict union
        # Future implementations might handle:
        # - Deep merging of nested structures
        # - Special handling of specific fields
        # - Array/list merging strategies
        # - Conflict resolution
        # - Metadata versioning
        if not new_metadata:
            return

        if override:
            self._metadata |= new_metadata  # new overrides existing
        else:
            self._metadata = new_metadata | self._metadata # existing values preserved

        logger.debug("Merging new metadata into TextObject")

    def update_metadata(self, **kwargs) -> None:
        """Update metadata with new key-value pairs."""
        new_metadata = Metadata(kwargs)
        self.merge_metadata(new_metadata)

    def validate_sections(self) -> None:
        """Basic validation of section integrity."""
        if not self._sections:
            raise ValueError("No sections set.")

        # Check section ordering and bounds
        for i, section in enumerate(self._sections):
            if section.section_range.start < 1:
                logger.warning(f"Section {i}: start line must be >= 1")
            if section.section_range.start > self.num_text.size:
                logger.warning(f"Section {i}: start line exceeds text length")
            if i > 0 and \
                section.section_range.start <= self._sections[i-1].section_range.start:
                logger.warning(f"Section {i}: non-sequential start line")

    def get_section_content(self, index: int) -> str:
        """Get content for a section by zero-based index."""
        if not self._sections:
            raise ValueError("No Sections available.")
        if index < 0 or index >= len(self._sections):
            raise IndexError("Section index out of range")

        section = self._sections[index]
        return self.num_text.get_segment(
            section.section_range.start, 
            section.section_range.end
        )

    def export_info(self, source_file: Optional[Path] = None) -> TextObjectInfo:
        """Export serializable state."""
        if source_file:
            source_file = source_file.resolve() # use absolute path for info

        return TextObjectInfo(
            source_file=source_file,
            language=self.language,
            sections=self.sections,
            metadata=self.metadata
        )

    @classmethod
    def from_info(
        cls, 
        info: TextObjectInfo, 
        metadata: Metadata, 
        num_text: 'NumberedText'
        ) -> 'TextObject':
        """Create TextObject from info and content."""
        text_obj = cls(
            num_text=num_text, 
            language=info.language, 
            sections=info.sections, 
            metadata=info.metadata
            )

        text_obj.merge_metadata(metadata)
        return text_obj

    @classmethod
    def from_text_file(
        cls,
        file: Path
    ) -> 'TextObject':
        text_str = read_str_from_file(file)
        return cls.from_str(text_str)

    @classmethod
    def from_section_file(
        cls, 
        section_file: Path, 
        source: Optional[str] = None
        ) -> 'TextObject':
        """
        Create TextObject from a section info file, loading content from source_file.
        Metadata is extracted from the source_file or from content.

        Args:
            section_file: Path to JSON file containing TextObjectInfo
            source: Optional source string in case no source file is found.

        Returns:
            TextObject instance

        Raises:
            ValueError: If source_file is missing from section info
            FileNotFoundError: If either section_file or source_file not found
        """
        # Check section file exists
        if not section_file.exists():
            raise FileNotFoundError(f"Section file not found: {section_file}")

        # Load and parse section info
        info = TextObjectInfo.model_validate_json(read_str_from_file(section_file))

        if not source:  # passed content always takes precedence over source_file
            # check if source file exists
            if not info.source_file:
                raise ValueError(f"No content available: no source_file specified "
                                 f"in section info: {section_file}")

            source_path = Path(info.source_file)
            if not source_path.exists():
                raise FileNotFoundError(
                    f"No content available: Source file not found: {source_path}"
                    )

            # Load source from path
            source = read_str_from_file(source_path)

        metadata, content = Frontmatter.extract(source)

        # Create TextObject
        return cls.from_info(info=info, 
                             metadata=metadata, 
                             num_text=NumberedText(content)
                             )

    def save(
        self,
        path: Path,
        output_format: StorageFormatType = StorageFormat.TEXT,
        source_file: Optional[Path] = None,
        pretty: bool = True
        ) -> None:
        """
        Save TextObject to file in specified format.

        Args:
            path: Output file path
            output_format: "text" for full content+metadata or "json" for serialized state
            source_file: Optional source file to record in metadata
            pretty: For JSON output, whether to pretty print
        """
        if isinstance(output_format, str):
            output_format = StorageFormat(output_format)

        if output_format == StorageFormat.TEXT:
            # Full text output with metadata as frontmatter
            write_str_to_file(path, str(self))

        elif output_format == StorageFormat.JSON:
            # Export serializable state
            info = self.export_info(source_file)
            json_str = info.model_dump_json(indent=2 if pretty else None)
            write_str_to_file(path, json_str)

    @classmethod
    def load(
        cls,
        path: Path,
        config: Optional[LoadConfig] = None
    ) -> 'TextObject':
        """
        Load TextObject from file with optional configuration.

        Args:
            path: Input file path
            config: Optional loading configuration. If not provided,
                loads directly from text file.

        Returns:
            TextObject instance

        Usage:
            # Load from text file with frontmatter
            obj = TextObject.load(Path("content.txt"))

            # Load state from JSON with source content string
            config = LoadConfig(
                format=StorageFormat.JSON,
                source_content="Text content..."
            )
            obj = TextObject.load(Path("state.json"), config)

            # Load state from JSON with source content file
            config = LoadConfig(
                format=StorageFormat.JSON,
                source_content=Path("content.txt")
            )
            obj = TextObject.load(Path("state.json"), config)
        """
        # Use default config if none provided
        config = config or LoadConfig()

        if config.format == StorageFormat.TEXT:
            return cls.from_text_file(path)

        elif config.format == StorageFormat.JSON:
            return cls.from_section_file(path, source=config.get_source_text())

        else:
            raise ValueError("Unknown load configuration format.")

    def transform(
        self,
        data_str: Optional[str] = None,
        language: Optional[str] = None, 
        metadata: Optional[Metadata] = None,
        process_metadata: Optional[ProcessMetadata] = None,
        sections: Optional[List[SectionObject]] = None
    ) -> Self:
        """Update TextObject content and metadata in place.

        Optionally modifies the object's content, language, and adds process tracking.
        Process history is maintained in metadata.

        Args:
            data_str: New text content (replaces num_text)
            language: New language code
            metadata: Metadata to merge into existing metadata
            process_metadata: Process info to append to the process history
            sections: New section list
        """
        # Update potentially changed elements
        if data_str:
            self.num_text = NumberedText(data_str)
        if language:
            self.language = language
        if metadata:
            self.merge_metadata(metadata)
        if process_metadata:    
            self._metadata.add_process_info(process_metadata)
        if sections:
            self._sections = sections

        return self

    @property
    def metadata(self) -> Metadata:
        """Access to metadata dictionary."""
        return self._metadata  

    @property
    def section_count(self) -> int:
        return len(self._sections) if self._sections else 0

    @property
    def last_line_num(self) -> int:
        return self.num_text.size

    @property
    def sections(self) -> List[SectionObject]:
        """Access to sections list."""
        return self._sections or []

    @property
    def content(self) -> str:
        return self.num_text.content

    @property
    def metadata_str(self) -> str:
        return self.metadata.to_yaml()

    @property
    def numbered_content(self) -> str:
        return self.num_text.numbered_content
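The implicit end-line convention used by `_build_section_objects` (each section ends where the next begins; the last ends at `last_line + 1`) can be sketched in plain Python. This is an illustrative stand-in, not the package's own types: `section_ranges` and the `(start, end)` tuples are hypothetical simplifications of `SectionObject` and its range.

```python
def section_ranges(start_lines, last_line):
    """Derive (start, end) ranges from 1-based section start lines.

    Each section ends where the next begins; the final section ends
    at last_line + 1 (end-exclusive).
    """
    ranges = []
    for i, start in enumerate(start_lines):
        end = start_lines[i + 1] if i < len(start_lines) - 1 else last_line + 1
        ranges.append((start, end))
    return ranges

# Three sections starting at lines 1, 10, and 25 of a 40-line text.
print(section_ranges([1, 10, 25], 40))  # [(1, 10), (10, 25), (25, 41)]
```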
content property
language = language or get_language_code_from_text(num_text.content) instance-attribute
last_line_num property
metadata property

Access to metadata dictionary.

metadata_str property
num_text = num_text instance-attribute
numbered_content property
section_count property
sections property

Access to sections list.

__init__(num_text, language=None, sections=None, metadata=None)

Initialize a TextObject with content and optional organizing components.

Parameters:

    num_text (NumberedText, required):
        Text content with line numbering
    language (Optional[str], default None):
        ISO 639-1 language code. If None, auto-detected from content
    sections (Optional[List[SectionObject]], default None):
        Initial sections defining text organization. If None, text is considered un-sectioned.
    metadata (Optional[Metadata], default None):
        Initial metadata. If None, creates empty metadata container
Note

Until sections are established, section-based methods will raise a value error if called.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def __init__(self, 
    num_text: NumberedText, 
    language: Optional[str] = None, 
    sections: Optional[List[SectionObject]] = None,
    metadata: Optional[Metadata] = None):
    """
    Initialize a TextObject with content and optional organizing components.

    Args:
        num_text: Text content with line numbering
        language: ISO 639-1 language code. If None, auto-detected from content
        sections: Initial sections defining text organization. If None, 
                  text is considered un-sectioned.
        metadata: Initial metadata. If None, creates empty metadata container

    Note:
        Until sections are established, section-based methods will raise a value
        error if called.
    """
    self.num_text = num_text
    self.language = language or get_language_code_from_text(num_text.content)
    self._sections = sections or []
    self._metadata = metadata or Metadata()

    if sections:
        self.validate_sections()
__iter__()

Iterate through sections, yielding full section information.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def __iter__(self) -> Iterator[SectionEntry]:
    """Iterate through sections, yielding full section information."""
    if not self._sections:
        raise ValueError("No Sections available.")

    for i, section in enumerate(self._sections):
        content = self.num_text.get_segment(
            section.section_range.start, 
            section.section_range.end
        )
        yield SectionEntry(
            number=i+1,
            title=section.title,
            range=section.section_range,
            content=content
        )
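The iteration protocol above can be sketched without the package's own classes. Here `Entry` and `iter_sections` are hypothetical stand-ins for `SectionEntry` and `TextObject.__iter__`, assuming 1-based, end-exclusive section ranges as used elsewhere in this class:

```python
from typing import Iterator, List, NamedTuple, Tuple

class Entry(NamedTuple):  # stand-in for SectionEntry
    number: int
    title: str
    content: str

def iter_sections(lines: List[str],
                  sections: List[Tuple[str, int, int]]) -> Iterator[Entry]:
    """Yield numbered entries; sections are (title, start, end),
    1-based and end-exclusive, mirroring __iter__ above."""
    if not sections:
        raise ValueError("No Sections available.")
    for i, (title, start, end) in enumerate(sections):
        yield Entry(i + 1, title, "\n".join(lines[start - 1:end - 1]))

lines = ["intro", "body one", "body two", "close"]
for e in iter_sections(lines, [("Opening", 1, 2), ("Body", 2, 4), ("Close", 4, 5)]):
    print(e.number, e.title, repr(e.content))
```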
__str__()
Source code in src/tnh_scholar/ai_text_processing/text_object.py
def __str__(self) -> str:
    return Frontmatter.embed(self.metadata, self.content)
export_info(source_file=None)

Export serializable state.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def export_info(self, source_file: Optional[Path] = None) -> TextObjectInfo:
    """Export serializable state."""
    if source_file:
        source_file = source_file.resolve() # use absolute path for info

    return TextObjectInfo(
        source_file=source_file,
        language=self.language,
        sections=self.sections,
        metadata=self.metadata
    )
from_info(info, metadata, num_text) classmethod

Create TextObject from info and content.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
@classmethod
def from_info(
    cls, 
    info: TextObjectInfo, 
    metadata: Metadata, 
    num_text: 'NumberedText'
    ) -> 'TextObject':
    """Create TextObject from info and content."""
    text_obj = cls(
        num_text=num_text, 
        language=info.language, 
        sections=info.sections, 
        metadata=info.metadata
        )

    text_obj.merge_metadata(metadata)
    return text_obj
from_response(response, existing_metadata, num_text) classmethod

Create TextObject from AI response format.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
@classmethod
def from_response(
    cls, 
    response: AIResponse,
    existing_metadata: Metadata,
    num_text: 'NumberedText'
) -> 'TextObject':
    """Create TextObject from AI response format."""
    # Create metadata from response
    ai_metadata = response.document_metadata
    new_metadata = Metadata({
        "ai_summary": response.document_summary,
        "ai_concepts": response.key_concepts,
        "ai_context": response.narrative_context
    })

    # Convert LogicalSections to SectionObjects
    sections = cls._build_section_objects(
        response.sections, 
        num_text.size,
    )

    text = cls(
        num_text=num_text,
        language=response.language,
        sections=sections,
        metadata=existing_metadata
    )
    text.merge_metadata(new_metadata)
    text.merge_metadata(Metadata.from_yaml(ai_metadata))
    return text
from_section_file(section_file, source=None) classmethod

Create TextObject from a section info file, loading content from source_file. Metadata is extracted from the source_file or from content.

Parameters:

    section_file (Path, required):
        Path to JSON file containing TextObjectInfo
    source (Optional[str], default None):
        Optional source string in case no source file is found.

Returns:

    TextObject: TextObject instance

Raises:

    ValueError: If source_file is missing from section info
    FileNotFoundError: If either section_file or source_file not found

Source code in src/tnh_scholar/ai_text_processing/text_object.py
@classmethod
def from_section_file(
    cls, 
    section_file: Path, 
    source: Optional[str] = None
    ) -> 'TextObject':
    """
    Create TextObject from a section info file, loading content from source_file.
    Metadata is extracted from the source_file or from content.

    Args:
        section_file: Path to JSON file containing TextObjectInfo
        source: Optional source string in case no source file is found.

    Returns:
        TextObject instance

    Raises:
        ValueError: If source_file is missing from section info
        FileNotFoundError: If either section_file or source_file not found
    """
    # Check section file exists
    if not section_file.exists():
        raise FileNotFoundError(f"Section file not found: {section_file}")

    # Load and parse section info
    info = TextObjectInfo.model_validate_json(read_str_from_file(section_file))

    if not source:  # passed content always takes precedence over source_file
        # check if source file exists
        if not info.source_file:
            raise ValueError(f"No content available: no source_file specified "
                             f"in section info: {section_file}")

        source_path = Path(info.source_file)
        if not source_path.exists():
            raise FileNotFoundError(
                f"No content available: Source file not found: {source_path}"
                )

        # Load source from path
        source = read_str_from_file(source_path)

    metadata, content = Frontmatter.extract(source)

    # Create TextObject
    return cls.from_info(info=info, 
                         metadata=metadata, 
                         num_text=NumberedText(content)
                         )
from_str(text, language=None, sections=None, metadata=None) classmethod

Create a TextObject from a string, extracting any frontmatter.

Parameters:

    text (str, required):
        Input text string, potentially containing frontmatter
    language (Optional[str], default None):
        ISO language code
    sections (Optional[List[SectionObject]], default None):
        List of section objects
    metadata (Optional[Metadata], default None):
        Optional base metadata to merge with frontmatter

Returns:

    TextObject: TextObject instance with combined metadata

Source code in src/tnh_scholar/ai_text_processing/text_object.py
@classmethod
def from_str(
    cls,
    text: str,
    language: Optional[str] = None,
    sections: Optional[List[SectionObject]] = None,
    metadata: Optional[Metadata] = None
) -> 'TextObject':
    """
    Create a TextObject from a string, extracting any frontmatter.

    Args:
        text: Input text string, potentially containing frontmatter
        language: ISO language code
        sections: List of section objects
        metadata: Optional base metadata to merge with frontmatter

    Returns:
        TextObject instance with combined metadata
    """
    # Extract any frontmatter and merge with provided metadata
    frontmatter_metadata, content = Frontmatter.extract(text)

    # Create NumberedText from content without frontmatter
    numbered_text = NumberedText(content)

    obj = cls(
        num_text=numbered_text,
        language=language,
        sections=sections,
        metadata=frontmatter_metadata
    )
    if metadata:
        obj.merge_metadata(metadata)

    return obj
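The frontmatter split that `from_str` delegates to `Frontmatter.extract` can be sketched with a minimal stand-in. This assumes the common `---`-delimited convention; the actual `Frontmatter` class may parse the block as YAML into a `Metadata` object rather than returning raw lines:

```python
from typing import List, Tuple

def extract_frontmatter(text: str) -> Tuple[List[str], str]:
    """Split optional '---'-delimited frontmatter from body text.

    Returns (metadata_lines, content). A minimal sketch of what
    Frontmatter.extract is assumed to do.
    """
    lines = text.split("\n")
    if lines and lines[0].strip() == "---":
        try:
            close = lines[1:].index("---") + 1
        except ValueError:
            return [], text  # unterminated block: treat everything as content
        return lines[1:close], "\n".join(lines[close + 1:])
    return [], text

meta, body = extract_frontmatter("---\ntitle: Talk\n---\nLine 1\nLine 2")
print(meta)  # ['title: Talk']
print(body)  # 'Line 1\nLine 2'
```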
from_text_file(file) classmethod
Source code in src/tnh_scholar/ai_text_processing/text_object.py
@classmethod
def from_text_file(
    cls,
    file: Path
) -> 'TextObject':
    text_str = read_str_from_file(file)
    return cls.from_str(text_str)
get_section_content(index)
Source code in src/tnh_scholar/ai_text_processing/text_object.py
def get_section_content(self, index: int) -> str:
    """Get content for a section by zero-based index."""
    if not self._sections:
        raise ValueError("No Sections available.")
    if index < 0 or index >= len(self._sections):
        raise IndexError("Section index out of range")

    section = self._sections[index]
    return self.num_text.get_segment(
        section.section_range.start, 
        section.section_range.end
    )
load(path, config=None) classmethod

Load TextObject from file with optional configuration.

Parameters:

    path (Path, required):
        Input file path
    config (Optional[LoadConfig], default None):
        Optional loading configuration. If not provided, loads directly from text file.

Returns:

    TextObject: TextObject instance

Usage:

    # Load from text file with frontmatter
    obj = TextObject.load(Path("content.txt"))

    # Load state from JSON with source content string
    config = LoadConfig(
        format=StorageFormat.JSON,
        source_content="Text content..."
    )
    obj = TextObject.load(Path("state.json"), config)

    # Load state from JSON with source content file
    config = LoadConfig(
        format=StorageFormat.JSON,
        source_content=Path("content.txt")
    )
    obj = TextObject.load(Path("state.json"), config)

Source code in src/tnh_scholar/ai_text_processing/text_object.py
@classmethod
def load(
    cls,
    path: Path,
    config: Optional[LoadConfig] = None
) -> 'TextObject':
    """
    Load TextObject from file with optional configuration.

    Args:
        path: Input file path
        config: Optional loading configuration. If not provided,
            loads directly from text file.

    Returns:
        TextObject instance

    Usage:
        # Load from text file with frontmatter
        obj = TextObject.load(Path("content.txt"))

        # Load state from JSON with source content string
        config = LoadConfig(
            format=StorageFormat.JSON,
            source_content="Text content..."
        )
        obj = TextObject.load(Path("state.json"), config)

        # Load state from JSON with source content file
        config = LoadConfig(
            format=StorageFormat.JSON,
            source_content=Path("content.txt")
        )
        obj = TextObject.load(Path("state.json"), config)
    """
    # Use default config if none provided
    config = config or LoadConfig()

    if config.format == StorageFormat.TEXT:
        return cls.from_text_file(path)

    elif config.format == StorageFormat.JSON:
        return cls.from_section_file(path, source=config.get_source_text())

    else:
        raise ValueError("Unknown load configuration format.")
merge_metadata(new_metadata, override=False)

Merge new metadata with existing metadata.

For now, performs simple dict-like union (|=) but can be extended to handle more complex merging logic in the future (e.g., merging nested structures, handling conflicts, merging arrays).

Args:

    new_metadata: Metadata to merge with existing metadata
    override: If True, new_metadata values override existing values.
        If False, existing values are preserved.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def merge_metadata(self, new_metadata: Metadata, override=False) -> None:
    """
    Merge new metadata with existing metadata.

    For now, performs simple dict-like union (|=) but can be extended 
    to handle more complex merging logic in the future (e.g., merging 
    nested structures, handling conflicts, merging arrays).

    Args:
        new_metadata: Metadata to merge with existing metadata
        override: If True, new_metadata values override existing values.
            If False, existing values are preserved.
    """
    # Currently using simple dict union
    # Future implementations might handle:
    # - Deep merging of nested structures
    # - Special handling of specific fields
    # - Array/list merging strategies
    # - Conflict resolution
    # - Metadata versioning
    if not new_metadata:
        return

    if override:
        self._metadata |= new_metadata  # new overrides existing
    else:
        self._metadata = new_metadata | self._metadata # existing values preserved

    logger.debug("Merging new metadata into TextObject")
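Since the docstring states the merge is a simple dict-like union, the override semantics can be demonstrated with plain dicts. The `merge` function below is a hypothetical stand-in for `merge_metadata`, relying only on the documented fact that in `a | b` the values of `b` win on key conflicts:

```python
def merge(existing: dict, new: dict, override: bool = False) -> dict:
    """Dict-union merge mirroring merge_metadata's documented semantics."""
    if not new:
        return existing
    # In `a | b`, values from b win on key conflicts.
    return existing | new if override else new | existing

existing = {"title": "Talk", "language": "en"}
new = {"language": "vi", "ai_summary": "..."}
print(merge(existing, new))                 # keeps language='en'
print(merge(existing, new, override=True))  # takes language='vi'
```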
save(path, output_format=StorageFormat.TEXT, source_file=None, pretty=True)

Save TextObject to file in specified format.

Parameters:

    path (Path, required):
        Output file path
    output_format (StorageFormatType, default StorageFormat.TEXT):
        "text" for full content+metadata or "json" for serialized state
    source_file (Optional[Path], default None):
        Optional source file to record in metadata
    pretty (bool, default True):
        For JSON output, whether to pretty print
Source code in src/tnh_scholar/ai_text_processing/text_object.py
def save(
    self,
    path: Path,
    output_format: StorageFormatType = StorageFormat.TEXT,
    source_file: Optional[Path] = None,
    pretty: bool = True
    ) -> None:
    """
    Save TextObject to file in specified format.

    Args:
        path: Output file path
        output_format: "text" for full content+metadata or "json" for serialized state
        source_file: Optional source file to record in metadata
        pretty: For JSON output, whether to pretty print
    """
    if isinstance(output_format, str):
        output_format = StorageFormat(output_format)

    if output_format == StorageFormat.TEXT:
        # Full text output with metadata as frontmatter
        write_str_to_file(path, str(self))

    elif output_format == StorageFormat.JSON:
        # Export serializable state
        info = self.export_info(source_file)
        json_str = info.model_dump_json(indent=2 if pretty else None)
        write_str_to_file(path, json_str)
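The two storage formats can be sketched with minimal stand-ins. `save_text` and `save_json` below are hypothetical simplifications: the real code delegates frontmatter embedding to `Frontmatter.embed`, file writing to `write_str_to_file`, and JSON serialization to pydantic's `model_dump_json`:

```python
import json
from pathlib import Path
from tempfile import TemporaryDirectory

def save_text(path: Path, metadata: dict, content: str) -> None:
    """TEXT format: metadata embedded as '---' frontmatter above the content."""
    meta_lines = "\n".join(f"{k}: {v}" for k, v in metadata.items())
    path.write_text(f"---\n{meta_lines}\n---\n{content}")

def save_json(path: Path, info: dict, pretty: bool = True) -> None:
    """JSON format: serialized state only (content lives in the source file)."""
    path.write_text(json.dumps(info, indent=2 if pretty else None))

with TemporaryDirectory() as d:
    save_text(Path(d) / "talk.txt", {"language": "en"}, "Line 1")
    print((Path(d) / "talk.txt").read_text())
```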
transform(data_str=None, language=None, metadata=None, process_metadata=None, sections=None)

Update TextObject content and metadata in place.

Optionally modifies the object's content, language, and adds process tracking. Process history is maintained in metadata.

Parameters:

    data_str (Optional[str], default None):
        New text content
    language (Optional[str], default None):
        New language code
    metadata (Optional[Metadata], default None):
        Metadata to merge into existing metadata
    process_metadata (Optional[ProcessMetadata], default None):
        Process info to append to the process history
    sections (Optional[List[SectionObject]], default None):
        New section list
Source code in src/tnh_scholar/ai_text_processing/text_object.py
def transform(
    self,
    data_str: Optional[str] = None,
    language: Optional[str] = None, 
    metadata: Optional[Metadata] = None,
    process_metadata: Optional[ProcessMetadata] = None,
    sections: Optional[List[SectionObject]] = None
) -> Self:
    """Update TextObject content and metadata in place.

    Optionally modifies the object's content, language, and adds process tracking.
    Process history is maintained in metadata.

    Args:
        data_str: New text content (replaces num_text)
        language: New language code
        metadata: Metadata to merge into existing metadata
        process_metadata: Process info to append to the process history
        sections: New section list
    """
    # Update potentially changed elements
    if data_str:
        self.num_text = NumberedText(data_str)
    if language:
        self.language = language
    if metadata:
        self.merge_metadata(metadata)
    if process_metadata:    
        self._metadata.add_process_info(process_metadata)
    if sections:
        self._sections = sections

    return self
update_metadata(**kwargs)

Update metadata with new key-value pairs.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def update_metadata(self, **kwargs) -> None:
    """Update metadata with new key-value pairs."""
    new_metadata = Metadata(kwargs)
    self.merge_metadata(new_metadata)
validate_sections()

Basic validation of section integrity.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def validate_sections(self) -> None:
    """Basic validation of section integrity."""
    if not self._sections:
        raise ValueError("No sections set.")

    # Check section ordering and bounds
    for i, section in enumerate(self._sections):
        if section.section_range.start < 1:
            logger.warning(f"Section {i}: start line must be >= 1")
        if section.section_range.start > self.num_text.size:
            logger.warning(f"Section {i}: start line exceeds text length")
        if i > 0 and \
            section.section_range.start <= self._sections[i-1].section_range.start:
            logger.warning(f"Section {i}: non-sequential start line")
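The checks in `validate_sections` can be sketched as a standalone function over the section start lines; a simplified reimplementation that collects the warning messages instead of logging them:

```python
from typing import List

def check_section_starts(starts: List[int], text_size: int) -> List[str]:
    """Collect ordering/bounds problems for a list of section start lines."""
    problems = []
    prev_start = 0
    for i, start in enumerate(starts):
        if start < 1:
            problems.append(f"Section {i}: start line must be >= 1")
        if start > text_size:
            problems.append(f"Section {i}: start line exceeds text length")
        if i > 0 and start <= prev_start:
            problems.append(f"Section {i}: non-sequential start line")
        prev_start = start
    return problems
```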

TextObjectInfo

Bases: BaseModel

Serializable information about a text and its sections.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
class TextObjectInfo(BaseModel):
    """Serializable information about a text and its sections."""
    source_file: Optional[Path] = None  # Original text file path
    language: str
    sections: List[SectionObject]
    metadata: Metadata

    def model_post_init(self, __context: Any) -> None:
        """Ensure metadata is always a Metadata instance after initialization."""
        if isinstance(self.metadata, dict):
            self.metadata = Metadata(self.metadata)
        elif not isinstance(self.metadata, Metadata):
            raise ValueError(f"Unexpected type for metadata: {type(self.metadata)}")
language instance-attribute
metadata instance-attribute
sections instance-attribute
source_file = None class-attribute instance-attribute
model_post_init(__context)

Ensure metadata is always a Metadata instance after initialization.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def model_post_init(self, __context: Any) -> None:
    """Ensure metadata is always a Metadata instance after initialization."""
    if isinstance(self.metadata, dict):
        self.metadata = Metadata(self.metadata)
    elif not isinstance(self.metadata, Metadata):
        raise ValueError(f"Unexpected type for metadata: {type(self.metadata)}")
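The dict-to-`Metadata` coercion in `model_post_init` can be shown in isolation. Here `Metadata` is stood in by a plain dict subclass so the sketch runs standalone; the real class is richer:

```python
class Metadata(dict):
    """Stand-in for the package's Metadata type (assumed dict-like)."""

def coerce_metadata(value):
    # Mirrors model_post_init: pass Metadata through, wrap plain dicts,
    # reject anything else.
    if isinstance(value, Metadata):
        return value
    if isinstance(value, dict):
        return Metadata(value)
    raise ValueError(f"Unexpected type for metadata: {type(value)}")
```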

__dir__()

Source code in src/tnh_scholar/ai_text_processing/__init__.py
def __dir__() -> list[str]:
    return sorted(__all__)

__getattr__(name)

Source code in src/tnh_scholar/ai_text_processing/__init__.py
def __getattr__(name: str) -> Any:
    module_path = _LAZY_ATTRS.get(name)
    if not module_path:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    module = import_module(module_path)
    value = getattr(module, name)
    globals()[name] = value
    return value
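This is the standard PEP 562 module-level `__getattr__` lazy-import pattern: resolve the attribute on first access, then cache it in the module globals so later lookups bypass the hook. A self-contained sketch (the `_LAZY_ATTRS` table below points at the stdlib, not the package's real table):

```python
from importlib import import_module
from typing import Any, Dict

_LAZY_ATTRS: Dict[str, str] = {"sqrt": "math"}  # illustrative mapping

def lazy_getattr(name: str, namespace: Dict[str, Any]) -> Any:
    module_path = _LAZY_ATTRS.get(name)
    if not module_path:
        raise AttributeError(f"no attribute {name!r}")
    value = getattr(import_module(module_path), name)
    namespace[name] = value  # cache so later lookups bypass this hook
    return value
```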

find_sections(text, source_language=None, section_pattern=None, section_model=None, max_tokens=DEFAULT_SECTION_RESULT_MAX_SIZE, section_count=None, review_count=DEFAULT_REVIEW_COUNT, template_dict=None)

High-level function for generating text sections.

Parameters:

Name Type Description Default
text TextObject

Input text

required
source_language Optional[str]

ISO 639-1 language code

None
section_pattern Optional[Prompt]

Optional custom pattern (uses default if None)

None
section_model Optional[str]

Optional model identifier

None
max_tokens int

Maximum tokens for response

DEFAULT_SECTION_RESULT_MAX_SIZE
section_count Optional[int]

Target number of sections

None
review_count int

Number of review passes

DEFAULT_REVIEW_COUNT
template_dict Optional[Dict[str, str]]

Optional additional template variables

None

Returns:

Type Description
TextObject

TextObject containing section breakdown

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def find_sections(
    text: TextObject,
    source_language: Optional[str] = None,
    section_pattern: Optional[Prompt] = None,
    section_model: Optional[str] = None,
    max_tokens: int = DEFAULT_SECTION_RESULT_MAX_SIZE,
    section_count: Optional[int] = None,
    review_count: int = DEFAULT_REVIEW_COUNT,
    template_dict: Optional[Dict[str, str]] = None,
) -> TextObject:
    """
    High-level function for generating text sections.

    Args:
        text: Input text
        source_language: ISO 639-1 language code
        section_pattern: Optional custom pattern (uses default if None)
        section_model: Optional model identifier
        max_tokens: Maximum tokens for response
        section_count: Target number of sections
        review_count: Number of review passes
        template_dict: Optional additional template variables

    Returns:
        TextObject containing section breakdown
    """
    if section_pattern is None:
        section_pattern = get_pattern(DEFAULT_SECTION_PATTERN)
        logger.debug(f"Using default section pattern: {DEFAULT_SECTION_PATTERN}.")

    section_scanner = OpenAIProcessor(model=section_model, max_tokens=max_tokens)
    parser = SectionParser(
        section_scanner=section_scanner,
        section_pattern=section_pattern,
        review_count=review_count,
    )

    process_metadata = ProcessMetadata(
            step="find_sections",
            processor="SectionProcessor", 
            source_language=source_language,
            pattern=section_pattern.name,
            model=section_model,
            section_count=section_count,
            review_count=review_count,
            template_dict=template_dict,
        )

    result_text = parser.find_sections(
        text,
        section_count_target=section_count,
        template_dict=template_dict,
    )
    result_text.transform(process_metadata=process_metadata)
    return result_text
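When only `section_count` is supplied, `find_sections` derives a per-section line target by division, and the parser widens the count into a tolerance range for the prompt. The arithmetic, in isolation (the constant mirrors `DEFAULT_SECTION_RANGE_VAR`):

```python
DEFAULT_SECTION_RANGE_VAR = 2  # mirrors the module-level constant

def section_targets(line_count: int, section_count: int,
                    range_var: int = DEFAULT_SECTION_RANGE_VAR):
    # Lines per section, plus the "low-high" range string that is
    # interpolated into the sectioning pattern template.
    segment_size = round(line_count / section_count)
    low = max(1, section_count - range_var)
    return segment_size, f"{low}-{section_count + range_var}"
```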

get_pattern(name)

Get a pattern by name using the singleton PatternManager.

This is a more efficient version that reuses a single PatternManager instance.

Parameters:

Name Type Description Default
name str

Name of the pattern to load

required

Returns:

Type Description
Prompt

The loaded pattern

Raises:

Type Description
ValueError

If pattern name is invalid

FileNotFoundError

If pattern file doesn't exist

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def get_pattern(name: str) -> Prompt:
    """
    Get a pattern by name using the singleton PatternManager.

    This is a more efficient version that reuses a single PatternManager instance.

    Args:
        name: Name of the pattern to load

    Returns:
        The loaded pattern

    Raises:
        ValueError: If pattern name is invalid
        FileNotFoundError: If pattern file doesn't exist
    """
    return LocalPromptManager().get_prompt(name)
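The reuse claim in the docstring typically comes from a singleton manager. A generic `__new__`-based sketch of that idea (not the package's actual `LocalPromptManager` implementation, which may cache differently):

```python
class PromptManagerSingleton:
    # Every call to the constructor returns the same cached instance,
    # so any pattern cache it holds survives across get_pattern() calls.
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
```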

openai_process_text(text_input, process_instructions, model=None, response_format=None, batch=False, max_tokens=0)

Process text with OpenAI, optionally as a batch job, with optional structured output.

Source code in src/tnh_scholar/ai_text_processing/openai_process_interface.py
def openai_process_text(
    text_input: str,
    process_instructions: str,
    model: Optional[str] = None,
    response_format: Optional[Type[BaseModel]] = None,
    batch: bool = False,
    max_tokens: int = 0,
) -> Union[BaseModel, str]:
    """Process text with OpenAI, optionally as a batch job, with optional structured output."""

    user_prompts = [text_input]
    system_message = process_instructions

    logger.debug(f"OpenAI Process Text with process instructions:\n{system_message}")
    if max_tokens == 0:
        tokens = token_count(text_input)
        max_tokens = tokens + TOKEN_BUFFER

    model_name = model or "default"

    logger.info(
        f"Open AI Text Processing{' as batch process' if batch else ''} "
        f"with model '{model_name}' initiated.\n"
        f"Requesting a maximum of {max_tokens} tokens."
    )

    if batch:
        return _run_batch_process_text(
            user_prompts, system_message, max_tokens, model_name, response_format
        )

    completion_result = simple_completion(
        system_message=system_message,
        user_message=text_input,
        model=model,
        max_tokens=max_tokens,
        response_model=response_format,
    )
    logger.info("Processing completed.")
    return completion_result
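When `max_tokens` is 0, the function budgets the response from the input's token count plus a fixed buffer. The branch in isolation (the `TOKEN_BUFFER` value below is an assumption; the real constant lives in the module):

```python
TOKEN_BUFFER = 500  # assumed value for illustration

def resolve_max_tokens(input_token_count: int, requested: int = 0) -> int:
    # Mirrors the branch in openai_process_text: a zero request means
    # "size the budget from the input length".
    if requested == 0:
        return input_token_count + TOKEN_BUFFER
    return requested
```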

process_text(text, pattern, source_language=None, model=None, template_dict=None)

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_text(
    text: TextObject,
    pattern: Prompt,
    source_language: Optional[str] = None,
    model: Optional[str] = None,
    template_dict: Optional[Dict] = None,
) -> TextObject:

    if not model:
        model = DEFAULT_OPENAI_MODEL

    processor = GeneralProcessor(
        processor=OpenAIProcessor(model),
        source_language=source_language,
        pattern=pattern,
    )

    process_metadata = ProcessMetadata(
            step="process_text",
            processor="GeneralProcessor",
            pattern=pattern.name,
            model=model,
            template_dict=template_dict,
        )

    result = processor.process_text(
        text, template_dict=template_dict
    )
    text.transform(data_str=result, process_metadata=process_metadata)
    return text

process_text_by_paragraphs(text, template_dict, pattern=None, model=None)

High-level function for processing text paragraphs, yielding ProcessedSection objects. Assumes paragraphs are separated by newlines. Uses DEFAULT_PARAGRAPH_FORMAT_PATTERN as the default pattern for text processing.

Parameters:

Name Type Description Default
text TextObject

TextObject to process

required
template_dict Dict[str, str]

Dictionary for template substitution

required
pattern Optional[Prompt]

Pattern object containing processing instructions

None
model Optional[str]

Optional model identifier for processor

None

Returns:

Type Description
Generator[ProcessedSection, None, None]

Generator for ProcessedSection objects (one per paragraph)

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_text_by_paragraphs(
    text: TextObject,
    template_dict: Dict[str, str],
    pattern: Optional[Prompt] = None,
    model: Optional[str] = None,
) -> Generator[ProcessedSection, None, None]:
    """
    High-level function for processing text paragraphs, yielding ProcessedSection objects.
    Assumes paragraphs are separated by newlines.
    Uses DEFAULT_PARAGRAPH_FORMAT_PATTERN as the default pattern for text processing.

    Args:
        text: TextObject to process
        template_dict: Dictionary for template substitution
        pattern: Pattern object containing processing instructions
        model: Optional model identifier for processor

    Returns:
        Generator for ProcessedSection objects (one per paragraph)
    """
    processor = OpenAIProcessor(model)

    if not pattern:
        pattern = get_pattern(DEFAULT_PARAGRAPH_FORMAT_PATTERN)

    section_processor = SectionProcessor(processor, pattern, template_dict)

    process_metadata = ProcessMetadata(
        step="process_text_by_paragraphs",
        processor="SectionProcessor",
        pattern=pattern.name,
        model=model,
        template_dict=template_dict,
    )

    result = section_processor.process_paragraphs(text)

    text.transform(process_metadata=process_metadata)

    return result

process_text_by_sections(text_object, template_dict, pattern, model=None)

High-level function for processing text sections with configurable output handling.

Parameters:

Name Type Description Default
text_object TextObject

Object containing section definitions

required
pattern Prompt

Pattern object containing processing instructions

required
template_dict Dict

Dictionary for template substitution

required
model Optional[str]

Optional model identifier for processor

None

Returns:

Type Description
Generator[ProcessedSection, None, None]

Generator for ProcessedSections

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_text_by_sections(
    text_object: TextObject,
    template_dict: Dict,
    pattern: Prompt,
    model: Optional[str] = None,
) -> Generator[ProcessedSection, None, None]:
    """
    High-level function for processing text sections with configurable output handling.

    Args:
        text_object: Object containing section definitions
        pattern: Pattern object containing processing instructions
        template_dict: Dictionary for template substitution
        model: Optional model identifier for processor

    Returns:
        Generator for ProcessedSections
    """
    processor = OpenAIProcessor(model)

    section_processor = SectionProcessor(processor, pattern, template_dict)

    process_metadata = ProcessMetadata(
            step="process_text_by_sections",
            processor="SectionProcessor",
            pattern=pattern.name,
            model=model,
            template_dict=template_dict,
        )
    result = section_processor.process_sections(text_object)

    text_object.transform(process_metadata=process_metadata)

    return result

translate_text_by_lines(text, source_language=None, target_language=DEFAULT_TARGET_LANGUAGE, pattern=None, model=None, style=None, segment_size=None, context_lines=None, review_count=None, template_dict=None)

Source code in src/tnh_scholar/ai_text_processing/line_translator.py
def translate_text_by_lines(
    text: TextObject,
    source_language: Optional[str] = None,
    target_language: str = DEFAULT_TARGET_LANGUAGE,
    pattern: Optional[Prompt] = None,
    model: Optional[str] = None,
    style: Optional[str] = None,
    segment_size: Optional[int] = None,
    context_lines: Optional[int] = None,
    review_count: Optional[int] = None,
    template_dict: Optional[Dict] = None,
) -> TextObject:

    if source_language is None:
        source_language = text.language

    if pattern is None:
        pattern = get_pattern(DEFAULT_TRANSLATION_PATTERN)

    processor = OpenAIProcessor(model)

    translator = LineTranslator(
        processor=processor,
        pattern=pattern,
        style=style or DEFAULT_TRANSLATE_STYLE,
        context_lines=context_lines or DEFAULT_TRANSLATE_CONTEXT_LINES,
        review_count=review_count or DEFAULT_REVIEW_COUNT,
    )

    process_metadata = ProcessMetadata(
            step="translation",
            processor="LineTranslator",
            model=processor.model,
            source_language=source_language,
            target_language=target_language,
            segment_size=segment_size,
            context_lines=translator.context_lines,
            review_count=translator.review_count,
            style=translator.style,
            template_dict=template_dict,
        )

    text = translator.translate_text(
        text,
        source_language=source_language,
        target_language=target_language,
        segment_size=segment_size,
        template_dict=template_dict,
    )
    return text.transform(process_metadata=process_metadata)
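Note that `translate_text_by_lines` resolves unset options with `x or DEFAULT`, so any falsy value (not just `None`) falls back to the default. A minimal illustration (the constant mirrors `DEFAULT_REVIEW_COUNT`):

```python
DEFAULT_REVIEW_COUNT = 5  # mirrors the module-level constant

def resolve_review_count(review_count=None):
    # `or`-style fallback, as used when constructing the LineTranslator;
    # note that 0 is also treated as "unset".
    return review_count or DEFAULT_REVIEW_COUNT
```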

ai_text_processing

DEFAULT_MIN_SECTION_COUNT = 3 module-attribute
DEFAULT_OPENAI_MODEL = 'gpt-4o' module-attribute
DEFAULT_PARAGRAPH_FORMAT_PATTERN = 'default_xml_paragraph_format' module-attribute
DEFAULT_PUNCTUATE_MODEL = 'gpt-4o' module-attribute
DEFAULT_PUNCTUATE_PATTERN = 'default_punctuate' module-attribute
DEFAULT_PUNCTUATE_STYLE = 'APA' module-attribute
DEFAULT_REVIEW_COUNT = 5 module-attribute
DEFAULT_SECTION_PATTERN = 'default_section' module-attribute
DEFAULT_SECTION_RANGE_VAR = 2 module-attribute
DEFAULT_SECTION_RESULT_MAX_SIZE = 4000 module-attribute
DEFAULT_SECTION_TOKEN_SIZE = 650 module-attribute
DEFAULT_XML_FORMAT_PATTERN = 'default_xml_format' module-attribute
SECTION_SEGMENT_SIZE_WARNING_LIMIT = 5 module-attribute
logger = get_child_logger(__name__) module-attribute
GeneralProcessor
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class GeneralProcessor:
    def __init__(
        self,
        processor: TextProcessor,
        pattern: Prompt,
        source_language: Optional[str] = None,
        review_count: int = DEFAULT_REVIEW_COUNT,
    ):
        """
        Initialize general processor.

        Args:
            processor: Implementation of TextProcessor
            pattern: Pattern object containing processing instructions
            source_language: Optional source language of the text
            review_count: Number of review passes
        """

        self.source_language = source_language
        self.processor = processor
        self.pattern = pattern
        self.review_count = review_count

    def process_text(
        self,
        text: TextObject,
        template_dict: Optional[Dict] = None,
    ) -> str:
        """
        Process a text based on a pattern and source language.
        """

        source_language = get_language_from_code(text.language)

        template_values = {
            "metadata": text.metadata_str,
            "source_language": source_language,
            "review_count": self.review_count,
        }

        if template_dict:
            template_values |= template_dict

        logger.info("Processing text...")
        instructions = self.pattern.apply_template(template_values)

        logger.debug(f"Process instructions:\n{instructions}")

        result = self.processor.process_text(text.content, instructions)

        logger.info("Processing completed.")

        # normalize newline spacing to two newlines between lines and return;
        # commented out to allow the pattern to dictate newlines:
        # return normalize_newlines(text)
        return result
pattern = pattern instance-attribute
processor = processor instance-attribute
review_count = review_count instance-attribute
source_language = source_language instance-attribute
__init__(processor, pattern, source_language=None, review_count=DEFAULT_REVIEW_COUNT)

Initialize general processor.

Parameters:

Name Type Description Default
processor TextProcessor

Implementation of TextProcessor

required
pattern Prompt

Pattern object containing processing instructions

required
source_language Optional[str]

Optional source language of the text

None
review_count int

Number of review passes

DEFAULT_REVIEW_COUNT
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def __init__(
    self,
    processor: TextProcessor,
    pattern: Prompt,
    source_language: Optional[str] = None,
    review_count: int = DEFAULT_REVIEW_COUNT,
):
    """
    Initialize general processor.

    Args:
        processor: Implementation of TextProcessor
        pattern: Pattern object containing processing instructions
        source_language: Optional source language of the text
        review_count: Number of review passes
    """

    self.source_language = source_language
    self.processor = processor
    self.pattern = pattern
    self.review_count = review_count
process_text(text, template_dict=None)

Process a text based on a pattern and source language.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_text(
    self,
    text: TextObject,
    template_dict: Optional[Dict] = None,
) -> str:
    """
    Process a text based on a pattern and source language.
    """

    source_language = get_language_from_code(text.language)

    template_values = {
        "metadata": text.metadata_str,
        "source_language": source_language,
        "review_count": self.review_count,
    }

    if template_dict:
        template_values |= template_dict

    logger.info("Processing text...")
    instructions = self.pattern.apply_template(template_values)

    logger.debug(f"Process instructions:\n{instructions}")

    result = self.processor.process_text(text.content, instructions)

    logger.info("Processing completed.")

    # normalize newline spacing to two newlines between lines and return;
    # commented out to allow the pattern to dictate newlines:
    # return normalize_newlines(text)
    return result
OpenAIProcessor

Bases: TextProcessor

OpenAI-based text processor implementation.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class OpenAIProcessor(TextProcessor):
    """OpenAI-based text processor implementation."""
    def __init__(self, model: Optional[str] = None, max_tokens: int = 0):
        if not model:
            model = DEFAULT_OPENAI_MODEL
        self.model = model
        self.max_tokens = max_tokens

    def process_text(
        self,
        input_str: str,
        instructions: str,
        response_format: Optional[Type[BaseModel]] = None,
        max_tokens: int = 0,
        **kwargs,
    ) -> ProcessorResult:
        """Process text using OpenAI API with optional structured output."""

        if max_tokens == 0 and self.max_tokens > 0:
            max_tokens = self.max_tokens

        return openai_process_text(
            input_str,
            instructions,
            model=self.model,
            max_tokens=max_tokens,
            response_format=response_format,
            **kwargs,
        )
max_tokens = max_tokens instance-attribute
model = model instance-attribute
__init__(model=None, max_tokens=0)
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def __init__(self, model: Optional[str] = None, max_tokens: int = 0):
    if not model:
        model = DEFAULT_OPENAI_MODEL
    self.model = model
    self.max_tokens = max_tokens
process_text(input_str, instructions, response_format=None, max_tokens=0, **kwargs)

Process text using OpenAI API with optional structured output.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_text(
    self,
    input_str: str,
    instructions: str,
    response_format: Optional[Type[BaseModel]] = None,
    max_tokens: int = 0,
    **kwargs,
) -> ProcessorResult:
    """Process text using OpenAI API with optional structured output."""

    if max_tokens == 0 and self.max_tokens > 0:
        max_tokens = self.max_tokens

    return openai_process_text(
        input_str,
        instructions,
        model=self.model,
        max_tokens=max_tokens,
        response_format=response_format,
        **kwargs,
    )
ProcessedSection dataclass

Represents a processed section of text with its metadata.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
@dataclass
class ProcessedSection:
    """Represents a processed section of text with its metadata."""
    title: str
    original_str: str
    processed_str: str
    metadata: Dict = field(default_factory=dict)
metadata = field(default_factory=dict) class-attribute instance-attribute
original_str instance-attribute
processed_str instance-attribute
title instance-attribute
__init__(title, original_str, processed_str, metadata=dict())
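The `field(default_factory=dict)` on `metadata` gives each instance its own fresh dict, avoiding the shared-mutable-default pitfall. A standalone illustration (the dataclass is re-declared here so the example runs on its own):

```python
from dataclasses import dataclass, field
from typing import Dict

@dataclass
class ProcessedSection:
    """Re-declaration of the dataclass above, for a runnable example."""
    title: str
    original_str: str
    processed_str: str
    metadata: Dict = field(default_factory=dict)

# Each instance gets an independent metadata dict.
a = ProcessedSection("Intro", "raw", "clean")
b = ProcessedSection("Body", "raw", "clean")
a.metadata["lang"] = "vi"
```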
SectionParser

Generates structured section breakdowns of text content.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class SectionParser:
    """Generates structured section breakdowns of text content."""

    def __init__(
        self,
        section_scanner: TextProcessor,
        section_pattern: Prompt,
        review_count: int = DEFAULT_REVIEW_COUNT,
    ):
        """
        Initialize section generator.

        Args:
            section_scanner: Implementation of TextProcessor
            section_pattern: Pattern object containing section generation instructions
            review_count: Number of review passes
        """
        self.section_scanner = section_scanner
        self.section_pattern = section_pattern
        self.review_count = review_count

    def find_sections(
        self,
        text: TextObject,
        section_count_target: Optional[int] = None,
        segment_size_target: Optional[int] = None,
        template_dict: Optional[Dict[str, str]] = None,
    ) -> TextObject:
        """
        Generate section breakdown of input text. The text must be split up by newlines.

        Args:
            text: Input TextObject to process
            section_count_target: the target for the number of sections to find
            segment_size_target: the target for the number of lines per section
                (if section_count_target is specified, 
                this value will be set to generate correct segments)
            template_dict: Optional additional template variables

        Returns:
            TextObject containing section breakdown
        """

        # Prepare numbered text, each line is numbered
        num_text = text.num_text

        if num_text.size < SECTION_SEGMENT_SIZE_WARNING_LIMIT:
            logger.warning(
                f"find_sections: Text has only {num_text.size} lines. "
                "This may lead to unexpected sectioning results."
            )

        # Get language if not specified
        source_language = get_language_from_code(text.language)

        # determine section count if not specified
        if not section_count_target:
            segment_size_target, section_count_target = self._get_section_count_info(
                text.content
            )
        elif not segment_size_target:
            segment_size_target = round(num_text.size / section_count_target)

        section_count_range = self._get_section_count_range(section_count_target)

        current_metadata = text.metadata

        # Prepare template variables
        template_values = {
            "metadata": current_metadata.to_yaml(),
            "source_language": source_language,
            "section_count": section_count_range,
            "line_count": segment_size_target,
            "review_count": self.review_count,
        }

        if template_dict:
            template_values |= template_dict

        # Get and apply processing instructions
        instructions = self.section_pattern.apply_template(template_values)
        logger.debug(f"Finding sections with pattern instructions:\n {instructions}")

        logger.info(
            f"Finding sections for {source_language} text "
            f"(target sections: {section_count_target})"
        )

        # Process text with structured output
        result = self.section_scanner.process_text(
            num_text.numbered_content, instructions, response_format=AIResponse
        )

        ai_response = cast(AIResponse, result)
        text_result = TextObject.from_response(ai_response, current_metadata, num_text)

        logger.info(f"Generated {text_result.section_count} sections.")

        return text_result

    def _get_section_count_info(self, text: str) -> Tuple[int, int]:
        num_text = NumberedText(text)
        segment_size = _calculate_segment_size(num_text, DEFAULT_SECTION_TOKEN_SIZE)
        section_count_target = round(num_text.size / segment_size)
        return segment_size, section_count_target

    def _get_section_count_range(
        self,
        section_count_target: int,
        section_range_var: int = DEFAULT_SECTION_RANGE_VAR,
    ) -> str:
        low = max(1, section_count_target - section_range_var)
        high = section_count_target + section_range_var
        return f"{low}-{high}"
review_count = review_count instance-attribute
section_pattern = section_pattern instance-attribute
section_scanner = section_scanner instance-attribute
__init__(section_scanner, section_pattern, review_count=DEFAULT_REVIEW_COUNT)

Initialize section generator.

Parameters:

Name Type Description Default
section_scanner TextProcessor

Implementation of TextProcessor

required
section_pattern Prompt

Pattern object containing section generation instructions

required
review_count int

Number of review passes

DEFAULT_REVIEW_COUNT
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def __init__(
    self,
    section_scanner: TextProcessor,
    section_pattern: Prompt,
    review_count: int = DEFAULT_REVIEW_COUNT,
):
    """
    Initialize section generator.

    Args:
        section_scanner: Implementation of TextProcessor
        section_pattern: Pattern object containing section generation instructions
        review_count: Number of review passes
    """
    self.section_scanner = section_scanner
    self.section_pattern = section_pattern
    self.review_count = review_count
find_sections(text, section_count_target=None, segment_size_target=None, template_dict=None)

Generate section breakdown of input text. The text must be split up by newlines.

Parameters:

Name Type Description Default
text TextObject

Input TextObject to process

required
section_count_target Optional[int]

the target for the number of sections to find

None
segment_size_target Optional[int]

the target for the number of lines per section (if section_count_target is specified, this value will be set to generate correct segments)

None
template_dict Optional[Dict[str, str]]

Optional additional template variables

None

Returns:

Type Description
TextObject

TextObject containing section breakdown

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def find_sections(
    self,
    text: TextObject,
    section_count_target: Optional[int] = None,
    segment_size_target: Optional[int] = None,
    template_dict: Optional[Dict[str, str]] = None,
) -> TextObject:
    """
    Generate a section breakdown of the input text. The text must be newline-delimited.

    Args:
        text: Input TextObject to process
        section_count_target: target number of sections to find
        segment_size_target: target number of lines per section
            (derived from section_count_target when that is specified)
        template_dict: Optional additional template variables

    Returns:
        TextObject containing section breakdown
    """

    # Prepare numbered text, each line is numbered
    num_text = text.num_text

    if num_text.size < SECTION_SEGMENT_SIZE_WARNING_LIMIT:
        logger.warning(
            f"find_sections: Text has only {num_text.size} lines. "
            "This may lead to unexpected sectioning results."
        )

    # Get language if not specified
    source_language = get_language_from_code(text.language)

    # determine section count if not specified
    if not section_count_target:
        segment_size_target, section_count_target = self._get_section_count_info(
            text.content
        )
    elif not segment_size_target:
        segment_size_target = round(num_text.size / section_count_target)

    section_count_range = self._get_section_count_range(section_count_target)

    current_metadata = text.metadata

    # Prepare template variables
    template_values = {
        "metadata": current_metadata.to_yaml(),
        "source_language": source_language,
        "section_count": section_count_range,
        "line_count": segment_size_target,
        "review_count": self.review_count,
    }

    if template_dict:
        template_values |= template_dict

    # Get and apply processing instructions
    instructions = self.section_pattern.apply_template(template_values)
    logger.debug(f"Finding sections with pattern instructions:\n {instructions}")

    logger.info(
        f"Finding sections for {source_language} text "
        f"(target sections: {section_count_target})"
    )

    # Process text with structured output
    result = self.section_scanner.process_text(
        num_text.numbered_content, instructions, response_format=AIResponse
    )

    ai_response = cast(AIResponse, result)
    text_result = TextObject.from_response(ai_response, current_metadata, num_text)

    logger.info(f"Generated {text_result.section_count} sections.")

    return text_result
SectionProcessor

Handles section-based XML text processing with configurable output handling.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class SectionProcessor:
    """Handles section-based XML text processing with configurable output handling."""

    def __init__(
        self,
        processor: TextProcessor,
        pattern: Prompt,
        template_dict: Dict,
        wrap_in_document: bool = True,
    ):
        """
        Initialize the XML section processor.

        Args:
            processor: Implementation of TextProcessor to use
            pattern: Pattern object containing processing instructions
            template_dict: Dictionary for template substitution
            wrap_in_document: Whether to wrap output in <document> tags
        """
        self.processor = processor
        self.pattern = pattern
        self.template_dict = template_dict
        self.wrap_in_document = wrap_in_document

    def process_sections(
        self,
        text_object: TextObject,
    ) -> Generator[ProcessedSection, None, None]:
        """
        Process transcript sections and yield results one section at a time.

        Args:
            text_object: TextObject containing the text and its section definitions

        Yields:
            ProcessedSection: One processed section at a time, containing:
                - title: Section title (English or original language)
                - original_str: Raw text segment
                - processed_str: Processed text content
        """
        # numbered_transcript = NumberedText(transcript) 
        # transcript is now stored in the TextObject
        sections = text_object.sections

        logger.info(
            f"Processing {len(sections)} sections with pattern: {self.pattern.name}"
        )

        for section_entry in text_object:
            logger.info(f"Processing section {section_entry.number} "
                        f"'{section_entry.title}':")

            # Get text segment for section
            text_segment = section_entry.content

            # Prepare template variables
            template_values = {
                "metadata": text_object.metadata.to_yaml(),
                "section_title": section_entry.title,
                "source_language": get_language_from_code(text_object.language),
                "review_count": DEFAULT_REVIEW_COUNT,
            }

            if self.template_dict:
                template_values |= self.template_dict

            # Get and apply processing instructions
            instructions = self.pattern.apply_template(template_values)
            processed_str = self.processor.process_text(text_segment, instructions)

            yield ProcessedSection(
                title=section_entry.title,
                original_str=text_segment,
                processed_str=processed_str,
            )

    def process_paragraphs(
        self,
        text: TextObject,
    ) -> Generator[ProcessedSection, None, None]:
        """
        Process transcript by paragraphs (as sections), yielding ProcessedSection objects.
        Paragraphs are assumed to be given as newline separated.

        Args:
            text: TextObject to process

        Yields:
            ProcessedSection: One processed paragraph at a time, containing:
                - title: Paragraph number (e.g., 'Paragraph 1')
                - original_str: Raw paragraph text
                - processed_str: Processed paragraph text
                - metadata: Optional metadata dict
        """
        num_text = text.num_text

        logger.info(f"Processing lines as paragraphs with pattern: {self.pattern.name}")

        for i, line in num_text:
            # If line is empty or whitespace, continue
            if not line.strip():
                continue

            instructions = self.pattern.apply_template(self.template_dict)

            if i <= 1:
                logger.debug(f"Process instructions (first paragraph):\n{instructions}")

            processed_str = self.processor.process_text(line, instructions)
            yield ProcessedSection(
                title=f"Paragraph {i}",
                original_str=line,
                processed_str=processed_str,
                metadata={"paragraph_number": i}
            )
pattern = pattern instance-attribute
processor = processor instance-attribute
template_dict = template_dict instance-attribute
wrap_in_document = wrap_in_document instance-attribute
__init__(processor, pattern, template_dict, wrap_in_document=True)

Initialize the XML section processor.

Parameters:

Name Type Description Default
processor TextProcessor

Implementation of TextProcessor to use

required
pattern Prompt

Pattern object containing processing instructions

required
template_dict Dict

Dictionary for template substitution

required
wrap_in_document bool

Whether to wrap output in <document> tags

True
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def __init__(
    self,
    processor: TextProcessor,
    pattern: Prompt,
    template_dict: Dict,
    wrap_in_document: bool = True,
):
    """
    Initialize the XML section processor.

    Args:
        processor: Implementation of TextProcessor to use
        pattern: Pattern object containing processing instructions
        template_dict: Dictionary for template substitution
        wrap_in_document: Whether to wrap output in <document> tags
    """
    self.processor = processor
    self.pattern = pattern
    self.template_dict = template_dict
    self.wrap_in_document = wrap_in_document
process_paragraphs(text)

Process transcript by paragraphs (as sections), yielding ProcessedSection objects. Paragraphs are assumed to be given as newline separated.

Parameters:

Name Type Description Default
text TextObject

TextObject to process

required

Yields:

Name Type Description
ProcessedSection ProcessedSection

One processed paragraph at a time, containing: - title: Paragraph number (e.g., 'Paragraph 1') - original_str: Raw paragraph text - processed_str: Processed paragraph text - metadata: Optional metadata dict

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_paragraphs(
    self,
    text: TextObject,
) -> Generator[ProcessedSection, None, None]:
    """
    Process transcript by paragraphs (as sections), yielding ProcessedSection objects.
    Paragraphs are assumed to be given as newline separated.

    Args:
        text: TextObject to process

    Yields:
        ProcessedSection: One processed paragraph at a time, containing:
            - title: Paragraph number (e.g., 'Paragraph 1')
            - original_str: Raw paragraph text
            - processed_str: Processed paragraph text
            - metadata: Optional metadata dict
    """
    num_text = text.num_text

    logger.info(f"Processing lines as paragraphs with pattern: {self.pattern.name}")

    for i, line in num_text:
        # If line is empty or whitespace, continue
        if not line.strip():
            continue

        instructions = self.pattern.apply_template(self.template_dict)

        if i <= 1:
            logger.debug(f"Process instructions (first paragraph):\n{instructions}")

        processed_str = self.processor.process_text(line, instructions)
        yield ProcessedSection(
            title=f"Paragraph {i}",
            original_str=line,
            processed_str=processed_str,
            metadata={"paragraph_number": i}
        )
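`process_paragraphs` treats each non-blank line as a paragraph while preserving the original line numbering. The iteration pattern can be sketched standalone (the name `iter_paragraphs` is hypothetical; lines are numbered from 1 as in `NumberedText`):

```python
def iter_paragraphs(text: str):
    """Yield (line_number, paragraph) for each non-blank line, numbering from 1."""
    for i, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            continue  # skip blank/whitespace-only lines, keep original numbering
        yield i, line

sample = "First paragraph.\n\nSecond paragraph.\n"
print(list(iter_paragraphs(sample)))
# [(1, 'First paragraph.'), (3, 'Second paragraph.')]
```

Note that the skipped blank line still advances the number, so paragraph numbers remain stable references into the source text.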
process_sections(text_object)

Process transcript sections and yield results one section at a time.

Parameters:

Name Type Description Default
text_object TextObject

Object containing section definitions

required

Yields:

Name Type Description
ProcessedSection ProcessedSection

One processed section at a time, containing: - title: Section title (English or original language) - original_str: Raw text segment - processed_str: Processed text content

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_sections(
    self,
    text_object: TextObject,
) -> Generator[ProcessedSection, None, None]:
    """
    Process transcript sections and yield results one section at a time.

    Args:
        text_object: TextObject containing the text and its section definitions

    Yields:
        ProcessedSection: One processed section at a time, containing:
            - title: Section title (English or original language)
            - original_str: Raw text segment
            - processed_str: Processed text content
    """
    # numbered_transcript = NumberedText(transcript) 
    # transcript is now stored in the TextObject
    sections = text_object.sections

    logger.info(
        f"Processing {len(sections)} sections with pattern: {self.pattern.name}"
    )

    for section_entry in text_object:
        logger.info(f"Processing section {section_entry.number} "
                    f"'{section_entry.title}':")

        # Get text segment for section
        text_segment = section_entry.content

        # Prepare template variables
        template_values = {
            "metadata": text_object.metadata.to_yaml(),
            "section_title": section_entry.title,
            "source_language": get_language_from_code(text_object.language),
            "review_count": DEFAULT_REVIEW_COUNT,
        }

        if self.template_dict:
            template_values |= self.template_dict

        # Get and apply processing instructions
        instructions = self.pattern.apply_template(template_values)
        processed_str = self.processor.process_text(text_segment, instructions)

        yield ProcessedSection(
            title=section_entry.title,
            original_str=text_segment,
            processed_str=processed_str,
        )
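Both processing paths merge caller-supplied template variables over the computed defaults with `|=`, so on a key collision the `template_dict` entry wins. A quick illustration of that precedence:

```python
# Defaults computed per section, then caller overrides merged on top.
template_values = {"review_count": 3, "source_language": "Vietnamese"}
overrides = {"review_count": 5}

template_values |= overrides  # right-hand operand wins on duplicate keys

print(template_values)  # {'review_count': 5, 'source_language': 'Vietnamese'}
```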
TextProcessor

Bases: ABC

Abstract base class for text processors that can return Pydantic objects.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class TextProcessor(ABC):
    """Abstract base class for text processors that can return Pydantic objects."""
    @abstractmethod
    def process_text(
        self,
        input_str: str,
        instructions: str,
        response_format: Optional[Type[BaseModel]] = None,
        **kwargs,
    ) -> ProcessorResult:
        """
        Process text according to instructions.

        Args:
            input_str: Input text to process
            instructions: Processing instructions
            response_format: Optional Pydantic class for structured output
            **kwargs: Additional processing parameters

        Returns:
            Either a string or a Pydantic model instance, based on response_format
        """
        pass
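The contract above can be exercised with a trivial concrete subclass. `EchoProcessor` below is a hypothetical test double, not part of the package, and the ABC is restated inline so the sketch is self-contained:

```python
from abc import ABC, abstractmethod
from typing import Optional, Type

class TextProcessor(ABC):
    """Abstract base for text processors (restated here for a runnable sketch)."""

    @abstractmethod
    def process_text(self, input_str: str, instructions: str,
                     response_format: Optional[Type] = None, **kwargs):
        ...

class EchoProcessor(TextProcessor):
    """Hypothetical test double: returns the input tagged with its instructions."""

    def process_text(self, input_str: str, instructions: str,
                     response_format: Optional[Type] = None, **kwargs):
        return f"[{instructions}] {input_str}"

proc = EchoProcessor()
print(proc.process_text("hello", "upper-case the text"))
# [upper-case the text] hello
```

A stub like this lets the sectioning and translation classes be tested without making any model calls.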
process_text(input_str, instructions, response_format=None, **kwargs) abstractmethod

Process text according to instructions.

Parameters:

Name Type Description Default
input_str str

Input text to process

required
instructions str

Processing instructions

required
response_format Optional[Type[BaseModel]]

Optional Pydantic class for structured output

None
**kwargs

Additional processing parameters

{}

Returns:

Type Description
ProcessorResult

Either a string or a Pydantic model instance, based on response_format

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
@abstractmethod
def process_text(
    self,
    input_str: str,
    instructions: str,
    response_format: Optional[Type[BaseModel]] = None,
    **kwargs,
) -> ProcessorResult:
    """
    Process text according to instructions.

    Args:
        input_str: Input text to process
        instructions: Processing instructions
        response_format: Optional Pydantic class for structured output
        **kwargs: Additional processing parameters

    Returns:
        Either a string or a Pydantic model instance, based on response_format
    """
    pass
find_sections(text, source_language=None, section_pattern=None, section_model=None, max_tokens=DEFAULT_SECTION_RESULT_MAX_SIZE, section_count=None, review_count=DEFAULT_REVIEW_COUNT, template_dict=None)

High-level function for generating text sections.

Parameters:

Name Type Description Default
text TextObject

Input text

required
source_language Optional[str]

ISO 639-1 language code

None
section_pattern Optional[Prompt]

Optional custom pattern (uses default if None)

None
section_model Optional[str]

Optional model identifier

None
max_tokens int

Maximum tokens for response

DEFAULT_SECTION_RESULT_MAX_SIZE
section_count Optional[int]

Target number of sections

None
review_count int

Number of review passes

DEFAULT_REVIEW_COUNT
template_dict Optional[Dict[str, str]]

Optional additional template variables

None

Returns:

Type Description
TextObject

TextObject containing section breakdown

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def find_sections(
    text: TextObject,
    source_language: Optional[str] = None,
    section_pattern: Optional[Prompt] = None,
    section_model: Optional[str] = None,
    max_tokens: int = DEFAULT_SECTION_RESULT_MAX_SIZE,
    section_count: Optional[int] = None,
    review_count: int = DEFAULT_REVIEW_COUNT,
    template_dict: Optional[Dict[str, str]] = None,
) -> TextObject:
    """
    High-level function for generating text sections.

    Args:
        text: Input text
        source_language: ISO 639-1 language code
        section_pattern: Optional custom pattern (uses default if None)
        section_model: Optional model identifier
        max_tokens: Maximum tokens for response
        section_count: Target number of sections
        review_count: Number of review passes
        template_dict: Optional additional template variables

    Returns:
        TextObject containing section breakdown
    """
    if section_pattern is None:
        section_pattern = get_pattern(DEFAULT_SECTION_PATTERN)
        logger.debug(f"Using default section pattern: {DEFAULT_SECTION_PATTERN}.")

    section_scanner = OpenAIProcessor(model=section_model, max_tokens=max_tokens)
    parser = SectionParser(
        section_scanner=section_scanner,
        section_pattern=section_pattern,
        review_count=review_count,
    )

    process_metadata = ProcessMetadata(
            step="find_sections",
            processor="SectionProcessor", 
            source_language=source_language,
            pattern=section_pattern.name,
            model=section_model,
            section_count=section_count,
            review_count=review_count,
            template_dict=template_dict,
        )

    result_text = parser.find_sections(
        text,
        section_count_target=section_count,
        template_dict=template_dict,
    )
    result_text.transform(process_metadata=process_metadata)
    return result_text
get_pattern(name)

Get a pattern by name using the singleton PatternManager.

This is a more efficient version that reuses a single PatternManager instance.

Parameters:

Name Type Description Default
name str

Name of the pattern to load

required

Returns:

Type Description
Prompt

The loaded pattern

Raises:

Type Description
ValueError

If pattern name is invalid

FileNotFoundError

If pattern file doesn't exist

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def get_pattern(name: str) -> Prompt:
    """
    Get a pattern by name using the singleton PatternManager.

    This is a more efficient version that reuses a single PatternManager instance.

    Args:
        name: Name of the pattern to load

    Returns:
        The loaded pattern

    Raises:
        ValueError: If pattern name is invalid
        FileNotFoundError: If pattern file doesn't exist
    """
    return LocalPromptManager().get_prompt(name)
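`get_pattern` reuses a singleton `PatternManager` so repeated lookups avoid re-reading pattern files. The caching idea can be sketched with `functools.lru_cache` (the loader below is hypothetical; the real implementation delegates to `LocalPromptManager`):

```python
from functools import lru_cache

LOADS = []  # tracks how many times the "pattern file" is actually read

@lru_cache(maxsize=None)
def get_pattern(name: str) -> str:
    LOADS.append(name)          # stands in for an expensive file read + parse
    return f"<pattern {name}>"  # stands in for a parsed Prompt object

get_pattern("default_section")
get_pattern("default_section")  # served from the cache, no second load
print(len(LOADS))  # 1
```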
process_text(text, pattern, source_language=None, model=None, template_dict=None)

High-level function for processing a TextObject with a pattern.
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_text(
    text: TextObject,
    pattern: Prompt,
    source_language: Optional[str] = None,
    model: Optional[str] = None,
    template_dict: Optional[Dict] = None,
) -> TextObject:
    """
    High-level function for processing a TextObject with a pattern.

    Args:
        text: Input TextObject to process
        pattern: Pattern object containing processing instructions
        source_language: Optional ISO 639-1 language code
        model: Optional model identifier (defaults to DEFAULT_OPENAI_MODEL)
        template_dict: Optional additional template variables

    Returns:
        The transformed TextObject
    """

    if not model:
        model = DEFAULT_OPENAI_MODEL

    processor = GeneralProcessor(
        processor=OpenAIProcessor(model),
        source_language=source_language,
        pattern=pattern,
    )

    process_metadata = ProcessMetadata(
            step="process_text",
            processor="GeneralProcessor",
            pattern=pattern.name,
            model=model,
            template_dict=template_dict,
        )

    result = processor.process_text(
        text, template_dict=template_dict
    )
    text.transform(data_str=result, process_metadata=process_metadata)
    return text
process_text_by_paragraphs(text, template_dict, pattern=None, model=None)

High-level function for processing text paragraphs, yielding ProcessedSection objects. Assumes paragraphs are separated by newlines. Uses DEFAULT_XML_FORMAT_PATTERN as default pattern for text processing.

Parameters:

Name Type Description Default
text TextObject

TextObject to process

required
template_dict Dict[str, str]

Dictionary for template substitution

required
pattern Optional[Prompt]

Pattern object containing processing instructions

None
model Optional[str]

Optional model identifier for processor

None

Returns:

Type Description
Generator[ProcessedSection, None, None]

Generator for ProcessedSection objects (one per paragraph)

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_text_by_paragraphs(
    text: TextObject,
    template_dict: Dict[str, str],
    pattern: Optional[Prompt] = None,
    model: Optional[str] = None,
) -> Generator[ProcessedSection, None, None]:
    """
    High-level function for processing text paragraphs, yielding ProcessedSection objects.
    Assumes paragraphs are separated by newlines.
    Uses DEFAULT_XML_FORMAT_PATTERN as default pattern for text processing.

    Args:
        text: TextObject to process
        template_dict: Dictionary for template substitution
        pattern: Pattern object containing processing instructions
        model: Optional model identifier for processor

    Returns:
        Generator for ProcessedSection objects (one per paragraph)
    """
    processor = OpenAIProcessor(model)

    if not pattern:
        pattern = get_pattern(DEFAULT_PARAGRAPH_FORMAT_PATTERN)

    section_processor = SectionProcessor(processor, pattern, template_dict)

    process_metadata = ProcessMetadata(
        step="process_text_by_paragraphs",
        processor="SectionProcessor",
        pattern=pattern.name,
        model=model,
        template_dict=template_dict,
    )

    result = section_processor.process_paragraphs(text)

    text.transform(process_metadata=process_metadata)

    return result
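Note that `process_text_by_paragraphs` returns a generator: no paragraph is processed until the caller iterates. A standalone illustration of that laziness (plain functions standing in for the package objects):

```python
calls = []

def process_paragraphs(paragraphs):
    for p in paragraphs:
        calls.append(p)   # stands in for one model call per paragraph
        yield p.upper()   # stands in for the processed result

gen = process_paragraphs(["one", "two"])
print(len(calls))           # 0 -- nothing has been processed yet
results = list(gen)         # iteration drives the processing
print(len(calls), results)  # 2 ['ONE', 'TWO']
```

Callers should therefore consume the generator (or wrap it in `list(...)`) before assuming any sections have been produced.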
process_text_by_sections(text_object, template_dict, pattern, model=None)

High-level function for processing text sections with configurable output handling.

Parameters:

Name Type Description Default
text_object TextObject

Object containing section definitions

required
pattern Prompt

Pattern object containing processing instructions

required
template_dict Dict

Dictionary for template substitution

required
model Optional[str]

Optional model identifier for processor

None

Returns:

Type Description
Generator[ProcessedSection, None, None]

Generator for ProcessedSections

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_text_by_sections(
    text_object: TextObject,
    template_dict: Dict,
    pattern: Prompt,
    model: Optional[str] = None,
) -> Generator[ProcessedSection, None, None]:
    """
    High-level function for processing text sections with configurable output handling.

    Args:
        text_object: TextObject containing section definitions
        pattern: Pattern object containing processing instructions
        template_dict: Dictionary for template substitution
        model: Optional model identifier for processor

    Returns:
        Generator for ProcessedSections
    """
    processor = OpenAIProcessor(model)

    section_processor = SectionProcessor(processor, pattern, template_dict)

    process_metadata = ProcessMetadata(
            step="process_text_by_sections",
            processor="SectionProcessor",
            pattern=pattern.name,
            model=model,
            template_dict=template_dict,
        )
    result = section_processor.process_sections(text_object)

    text_object.transform(process_metadata=process_metadata)

    return result

general_processor

line_translator

DEFAULT_TARGET_LANGUAGE = 'en' module-attribute
DEFAULT_TRANSLATE_CONTEXT_LINES = 3 module-attribute
DEFAULT_TRANSLATE_STYLE = "'American Dharma Teaching'" module-attribute
DEFAULT_TRANSLATION_PATTERN = 'default_line_translate' module-attribute
DEFAULT_TRANSLATION_TARGET_TOKENS = 300 module-attribute
FOLLOWING_CONTEXT_MARKER = 'FOLLOWING_CONTEXT' module-attribute
MAX_RETRIES = 6 module-attribute
MIN_SEGMENT_SIZE = 4 module-attribute
PRECEDING_CONTEXT_MARKER = 'PRECEDING_CONTEXT' module-attribute
TRANSCRIPT_SEGMENT_MARKER = 'TRANSCRIPT_SEGMENT' module-attribute
logger = get_child_logger(__name__) module-attribute
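The context markers above delimit the three blocks of a translation request: optional preceding context, the segment to translate, and optional following context. A minimal standalone sketch of that framing (mirroring the class's `_build_translation_input`, with the package objects replaced by plain strings):

```python
PRECEDING_CONTEXT_MARKER = "PRECEDING_CONTEXT"
TRANSCRIPT_SEGMENT_MARKER = "TRANSCRIPT_SEGMENT"
FOLLOWING_CONTEXT_MARKER = "FOLLOWING_CONTEXT"

def build_translation_input(preceding: str, segment: str, following: str) -> str:
    """Wrap each block in its marker; context blocks are included only if present."""
    parts = []
    if preceding:
        parts += [PRECEDING_CONTEXT_MARKER, preceding, PRECEDING_CONTEXT_MARKER, ""]
    parts += [TRANSCRIPT_SEGMENT_MARKER, segment, TRANSCRIPT_SEGMENT_MARKER, ""]
    if following:
        parts += [FOLLOWING_CONTEXT_MARKER, following, FOLLOWING_CONTEXT_MARKER, ""]
    return "\n".join(parts)

print(build_translation_input("1: before", "2: target line", "3: after"))
```

The repeated marker on both sides of each block lets the model (and the validation step) locate the segment boundaries unambiguously.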
LineTranslator

Translates text line by line while maintaining line numbers and context.

Source code in src/tnh_scholar/ai_text_processing/line_translator.py
class LineTranslator:
    """Translates text line by line while maintaining line numbers and context."""

    def __init__(
        self,
        processor: TextProcessor,
        pattern: Prompt,
        review_count: int = DEFAULT_REVIEW_COUNT,
        style: str = DEFAULT_TRANSLATE_STYLE,
        # Number of context lines before/after
        context_lines: int = DEFAULT_TRANSLATE_CONTEXT_LINES,  
    ):
        """
        Initialize line translator.

        Args:
            processor: Implementation of TextProcessor
            pattern: Pattern object containing translation instructions
            review_count: Number of review passes
            style: Translation style to apply
            context_lines: Number of context lines to include before/after
        """
        self.processor = processor
        self.pattern = pattern
        self.review_count = review_count
        self.style = style
        self.context_lines = context_lines

    def translate_segment(
        self,
        num_text: NumberedText,
        start_line: int,
        end_line: int,
        metadata: Metadata,
        target_language: str,
        source_language: str,
        template_dict: Optional[Dict] = None,
    ) -> str:
        """
        Translate a segment of text with context.

        Args:
            num_text: NumberedText to extract the segment from
            start_line: Starting line number of segment
            end_line: Ending line number of segment
            metadata: Metadata for the text
            target_language: Target language code (default: en for English)
            source_language: Source language code
            template_dict: Optional additional template values

        Returns:
            Translated text segment with line numbers preserved
        """

        # Calculate context ranges
        preceding_start = max(1, start_line - self.context_lines)  # lines start on 1.
        following_end = min(num_text.end + 1, end_line + self.context_lines)

        # Extract context and segment
        preceding_context = num_text.get_numbered_segment(preceding_start, start_line)
        transcript_segment = num_text.get_numbered_segment(start_line, end_line)
        following_context = num_text.get_numbered_segment(end_line, following_end)

        # build input text
        translation_input = self._build_translation_input(
            preceding_context, transcript_segment, following_context
        )

        # Prepare template values
        template_values = {
            "source_language": get_language_from_code(source_language),
            "target_language": get_language_from_code(target_language),
            "review_count": self.review_count,
            "style": self.style,
            "metadata": metadata.to_yaml()
        }

        if template_dict:
            template_values |= template_dict

        # Get and apply translation instructions
        logger.info(f"Translating segment (lines {start_line}-{end_line})")
        translate_instructions = self.pattern.apply_template(template_values)

        if start_line <= 1:
            logger.debug(
                f"Translate instructions (first segment):\n{translate_instructions}"
            )
        logger.debug(f"Translation input:\n{translation_input}")

        return self._translate_with_retries(
            translation_input,
            translate_instructions,
            start_line,
            end_line,
            )

    def _translate_with_retries(
        self, 
        translation_input,
        translate_instructions, 
        start_line, end_line
        ) -> str:

        retries = 0
        translated_lines = ""

        while retries < MAX_RETRIES:
            translated_segment = self.processor.process_text(
                translation_input, translate_instructions)
            translated_lines = self._extract_lines(translated_segment)

            if not self._validate_lines(translated_lines, start_line, end_line):
                break  # Validation successful, exit loop

            retries += 1
            logger.warning(f"Validation failed for segment {start_line}-{end_line}, "
                           f"retrying (attempt {retries + 1}/{MAX_RETRIES})")
            # You might want to add a delay here, e.g., using time.sleep()

        if retries == MAX_RETRIES:
            logger.error(
                f"Validation failed after {MAX_RETRIES}"
                f" retries for segment {start_line}-{end_line}\n"
                "Using last generated result."
                )

        return self._extract_content(translated_lines)

    def _extract_content(self, lines: str) -> str:
        """convert line-numbered format to un-numbered text."""
        # clean each line and return full clean segment
        line_list = lines.splitlines()
        # Remove line numbering and strip whitespace
        stripped_lines = [line.split(':', 1)[-1].strip() for line in line_list]
        return "\n".join(stripped_lines)


    def _build_translation_input(
        self, preceding_context: str, transcript_segment: str, following_context: str
    ) -> str:
        """
        Build input text in required XML-style format.

        Args:
            preceding_context: Context lines before segment
            transcript_segment: Main segment to translate
            following_context: Context lines after segment

        Returns:
            Formatted input text
        """
        parts = []

        # Add preceding context if exists
        if preceding_context:
            parts.extend(
                [
                    PRECEDING_CONTEXT_MARKER,
                    preceding_context,
                    PRECEDING_CONTEXT_MARKER,
                    "",
                ]
            )

        # Add main segment (always required)
        parts.extend(
            [
                TRANSCRIPT_SEGMENT_MARKER,
                transcript_segment,
                TRANSCRIPT_SEGMENT_MARKER,
                "",
            ]
        )

        # Add following context if exists
        if following_context:
            parts.extend(
                [
                    FOLLOWING_CONTEXT_MARKER,
                    following_context,
                    FOLLOWING_CONTEXT_MARKER,
                    "",
                ]
            )

        return "\n".join(parts)

    def translate_text(
        self,
        text: TextObject,
        source_language: str,
        segment_size: Optional[int] = None,  
        target_language: str = DEFAULT_TARGET_LANGUAGE,
        template_dict: Optional[Dict] = None,
    ) -> TextObject:
        """
        Translate entire text in segments while maintaining line continuity.

        Args:
            text: Text to translate
            segment_size: Number of lines per translation segment
            source_language: Source language code
            target_language: Target language code (default: en for English)
            template_dict: Optional additional template values

        Returns:
            Complete translated text with line numbers preserved
        """

        # Use TextObject language if not specified
        if not source_language:
            source_language = text.language

        # Convert text to numbered lines
        num_text = text.num_text
        total_lines = num_text.size

        metadata = text.metadata

        if not segment_size:
            segment_size = _calculate_segment_size(
                num_text, DEFAULT_TRANSLATION_TARGET_TOKENS
            )

        translated_segments = []

        logger.debug(
            f"Total lines to translate: {total_lines} "
            f" | Translation segment size: {segment_size}."
        )
        # Process text in segments using segment iteration
        for start_idx, end_idx in num_text.iter_segments(
            segment_size, min_segment_size=MIN_SEGMENT_SIZE
        ):
            translated_content = self.translate_segment(
                num_text=num_text,
                start_line=start_idx,
                end_line=end_idx,
                metadata=metadata,
                source_language=source_language,
                target_language=target_language,
                template_dict=template_dict,
            )

            translated_segments.append(translated_content)

        new_text = "\n".join(translated_segments).strip()

        return text.transform(
            data_str=new_text, 
            language=target_language, 
            )

    def _extract_lines(self, segment: str) -> str:
        if segment.startswith(TRANSCRIPT_SEGMENT_MARKER) and segment.endswith(
            TRANSCRIPT_SEGMENT_MARKER
        ):
            segment = segment[
                len(TRANSCRIPT_SEGMENT_MARKER) : -len(TRANSCRIPT_SEGMENT_MARKER)
            ]

        else:
            logger.warning("Translated segment missing transcript_segment tags")

        # clean each line and return full clean segment
        lines = segment.splitlines()
        stripped_lines = [line.strip() for line in lines]
        segment = "\n".join(stripped_lines)
        return segment.strip()

    def _validate_lines(
        self, translated_content: str, start_index: int, end_index: int
    ) -> bool:
        """
        Validate translated segment format, content, and line number sequence.
        Issues warnings for validation issues rather than raising errors.

        Args:
            translated_segment: Translated text to validate
            start_idx: the staring index of the range (inclusive)
            end_line: then ending index of the range (exclusive)

        Returns:
            str: Content with segment tags removed
        """

        # Validate lines

        error_count = 0
        lines = translated_content.splitlines()
        line_numbers = []

        start_line = start_index  # inclusive start
        end_line = end_index - 1  # exclusive end

        for line in lines:
            line = line.strip()
            if not line:
                continue

            if ":" not in line:
                logger.warning(f"Invalid line format: {line}")
                error_count += 1
                continue

            try:
                line_num = int(line[: line.index(":")])
                if line_num < 0:
                    logger.warning(f"Invalid line number: {line}")
                    error_count += 1
                    continue
                line_numbers.append(line_num)
            except ValueError:
                logger.warning(f"Line number parsing failed: {line}")
                error_count += 1
                continue

        # Validate sequence
        if not line_numbers:
            logger.warning("No valid line numbers found")
            error_count += 1
        else:
            if line_numbers[0] != start_line:
                logger.warning(
                    f"First line number {line_numbers[0]} "
                    f" doesn't match expected {start_line}"
                )
                error_count += 1

            if line_numbers[-1] != end_line:
                logger.warning(
                    f"Last line number {line_numbers[-1]} "
                    f"doesn't match expected {end_line}"
                )
                error_count += 1

            expected = set(range(start_line, end_line + 1))
            if missing := expected - set(line_numbers):
                logger.warning(f"Missing line numbers in sequence: {missing}")
                error_count += len(missing)

        logger.debug(f"Validated {len(lines)} lines from {start_line} to {end_line}\n"
                     f"{error_count} errors encountered.")
        return error_count > 0  # True indicates validation errors were found
context_lines = context_lines instance-attribute
pattern = pattern instance-attribute
processor = processor instance-attribute
review_count = review_count instance-attribute
style = style instance-attribute
__init__(processor, pattern, review_count=DEFAULT_REVIEW_COUNT, style=DEFAULT_TRANSLATE_STYLE, context_lines=DEFAULT_TRANSLATE_CONTEXT_LINES)

Initialize line translator.

Parameters:

    processor (TextProcessor): Implementation of TextProcessor. Required.
    pattern (Prompt): Pattern object containing translation instructions. Required.
    review_count (int): Number of review passes. Default: DEFAULT_REVIEW_COUNT.
    style (str): Translation style to apply. Default: DEFAULT_TRANSLATE_STYLE.
    context_lines (int): Number of context lines to include before/after. Default: DEFAULT_TRANSLATE_CONTEXT_LINES.
Source code in src/tnh_scholar/ai_text_processing/line_translator.py
def __init__(
    self,
    processor: TextProcessor,
    pattern: Prompt,
    review_count: int = DEFAULT_REVIEW_COUNT,
    style: str = DEFAULT_TRANSLATE_STYLE,
    # Number of context lines before/after
    context_lines: int = DEFAULT_TRANSLATE_CONTEXT_LINES,  
):
    """
    Initialize line translator.

    Args:
        processor: Implementation of TextProcessor
        pattern: Pattern object containing translation instructions
        review_count: Number of review passes
        style: Translation style to apply
        context_lines: Number of context lines to include before/after
    """
    self.processor = processor
    self.pattern = pattern
    self.review_count = review_count
    self.style = style
    self.context_lines = context_lines
translate_segment(num_text, start_line, end_line, metadata, target_language, source_language, template_dict=None)

Translate a segment of text with context.

Parameters:

    num_text (NumberedText): Numbered text to extract the segment from. Required.
    start_line (int): Starting line number of segment. Required.
    end_line (int): Ending line number of segment. Required.
    metadata (Metadata): Metadata for the text. Required.
    source_language (str): Source language code. Required.
    target_language (str): Target language code. Required.
    template_dict (Optional[Dict]): Optional additional template values. Default: None.

Returns:

    str: Translated text segment with line numbers preserved.

Source code in src/tnh_scholar/ai_text_processing/line_translator.py
def translate_segment(
    self,
    num_text: NumberedText,
    start_line: int,
    end_line: int,
    metadata: Metadata,
    target_language: str,
    source_language: str,
    template_dict: Optional[Dict] = None,
) -> str:
    """
    Translate a segment of text with context.

    Args:
        num_text: Numbered text to extract the segment from
        start_line: Starting line number of segment
        end_line: Ending line number of segment
        metadata: Metadata for the text
        source_language: Source language code
        target_language: Target language code (default: en for English)
        template_dict: Optional additional template values

    Returns:
        Translated text segment with line numbers preserved
    """

    # Calculate context ranges
    preceding_start = max(1, start_line - self.context_lines)  # lines start on 1.
    following_end = min(num_text.end + 1, end_line + self.context_lines)

    # Extract context and segment
    preceding_context = num_text.get_numbered_segment(preceding_start, start_line)
    transcript_segment = num_text.get_numbered_segment(start_line, end_line)
    following_context = num_text.get_numbered_segment(end_line, following_end)

    # build input text
    translation_input = self._build_translation_input(
        preceding_context, transcript_segment, following_context
    )

    # Prepare template values
    template_values = {
        "source_language": get_language_from_code(source_language),
        "target_language": get_language_from_code(target_language),
        "review_count": self.review_count,
        "style": self.style,
        "metadata": metadata.to_yaml()
    }

    if template_dict:
        template_values |= template_dict

    # Get and apply translation instructions
    logger.info(f"Translating segment (lines {start_line}-{end_line})")
    translate_instructions = self.pattern.apply_template(template_values)

    if start_line <= 1:
        logger.debug(
            f"Translate instructions (first segment):\n{translate_instructions}"
        )
    logger.debug(f"Translation input:\n{translation_input}")

    return self._translate_with_retries(
        translation_input,
        translate_instructions,
        start_line,
        end_line,
        )
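The context window above clamps to the text bounds, with line numbers starting at 1. That arithmetic in isolation (assuming `num_text.end` equals the last line number; the helper name is illustrative):

```python
def context_window(start_line, end_line, total_lines, context_lines):
    """Compute (preceding_start, following_end) clamped to [1, total_lines + 1)."""
    preceding_start = max(1, start_line - context_lines)   # lines start at 1
    following_end = min(total_lines + 1, end_line + context_lines)
    return preceding_start, following_end

print(context_window(1, 20, 100, 3))    # (1, 23)  - clamped at the start
print(context_window(95, 100, 100, 3))  # (92, 101) - clamped at the end
```

The segment itself spans `[start_line, end_line)`, so the preceding context is `[preceding_start, start_line)` and the following context `[end_line, following_end)`.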
translate_text(text, source_language, segment_size=None, target_language=DEFAULT_TARGET_LANGUAGE, template_dict=None)

Translate entire text in segments while maintaining line continuity.

Parameters:

    text (TextObject): Text to translate. Required.
    source_language (str): Source language code. Required.
    segment_size (Optional[int]): Number of lines per translation segment. Default: None.
    target_language (str): Target language code. Default: DEFAULT_TARGET_LANGUAGE.
    template_dict (Optional[Dict]): Optional additional template values. Default: None.

Returns:

    TextObject: Complete translated text with line numbers preserved.

Source code in src/tnh_scholar/ai_text_processing/line_translator.py
def translate_text(
    self,
    text: TextObject,
    source_language: str,
    segment_size: Optional[int] = None,  
    target_language: str = DEFAULT_TARGET_LANGUAGE,
    template_dict: Optional[Dict] = None,
) -> TextObject:
    """
    Translate entire text in segments while maintaining line continuity.

    Args:
        text: Text to translate
        segment_size: Number of lines per translation segment
        source_language: Source language code
        target_language: Target language code (default: en for English)
        template_dict: Optional additional template values

    Returns:
        Complete translated text with line numbers preserved
    """

    # Use TextObject language if not specified
    if not source_language:
        source_language = text.language

    # Convert text to numbered lines
    num_text = text.num_text
    total_lines = num_text.size

    metadata = text.metadata

    if not segment_size:
        segment_size = _calculate_segment_size(
            num_text, DEFAULT_TRANSLATION_TARGET_TOKENS
        )

    translated_segments = []

    logger.debug(
        f"Total lines to translate: {total_lines} "
        f" | Translation segment size: {segment_size}."
    )
    # Process text in segments using segment iteration
    for start_idx, end_idx in num_text.iter_segments(
        segment_size, min_segment_size=MIN_SEGMENT_SIZE
    ):
        translated_content = self.translate_segment(
            num_text=num_text,
            start_line=start_idx,
            end_line=end_idx,
            metadata=metadata,
            source_language=source_language,
            target_language=target_language,
            template_dict=template_dict,
        )

        translated_segments.append(translated_content)

    new_text = "\n".join(translated_segments).strip()

    return text.transform(
        data_str=new_text, 
        language=target_language, 
        )
translate_text_by_lines(text, source_language=None, target_language=DEFAULT_TARGET_LANGUAGE, pattern=None, model=None, style=None, segment_size=None, context_lines=None, review_count=None, template_dict=None)
Source code in src/tnh_scholar/ai_text_processing/line_translator.py
def translate_text_by_lines(
    text: TextObject,
    source_language: Optional[str] = None,
    target_language: str = DEFAULT_TARGET_LANGUAGE,
    pattern: Optional[Prompt] = None,
    model: Optional[str] = None,
    style: Optional[str] = None,
    segment_size: Optional[int] = None,
    context_lines: Optional[int] = None,
    review_count: Optional[int] = None,
    template_dict: Optional[Dict] = None,
) -> TextObject:
    """
    Translate a TextObject line by line using a LineTranslator.

    Unspecified pattern, model, style, context, and review options fall back
    to module defaults; process metadata is attached to the returned text.
    """

    if source_language is None:
        source_language = text.language

    if pattern is None:
        pattern = get_pattern(DEFAULT_TRANSLATION_PATTERN)

    processor = OpenAIProcessor(model)

    translator = LineTranslator(
        processor=processor,
        pattern=pattern,
        style=style or DEFAULT_TRANSLATE_STYLE,
        context_lines=context_lines or DEFAULT_TRANSLATE_CONTEXT_LINES,
        review_count=review_count or DEFAULT_REVIEW_COUNT,
    )

    process_metadata = ProcessMetadata(
            step="translation",
            processor="LineTranslator",
            model=processor.model,
            source_language=source_language,
            target_language=target_language,
            segment_size=segment_size,
            context_lines=translator.context_lines,
            review_count=translator.review_count,
            style=translator.style,
            template_dict=template_dict,
        )

    text = translator.translate_text(
        text,
        source_language=source_language,
        target_language=target_language,
        segment_size=segment_size,
        template_dict=template_dict,
    )
    return text.transform(process_metadata=process_metadata)

openai_process_interface

TOKEN_BUFFER = 500 module-attribute
logger = get_child_logger(__name__) module-attribute
openai_process_text(text_input, process_instructions, model=None, response_format=None, batch=False, max_tokens=0)

Process text through the OpenAI API using the given processing instructions.

Source code in src/tnh_scholar/ai_text_processing/openai_process_interface.py
def openai_process_text(
    text_input: str,
    process_instructions: str,
    model: Optional[str] = None,
    response_format: Optional[Type[BaseModel]] = None,
    batch: bool = False,
    max_tokens: int = 0,
) -> Union[BaseModel, str]:
    """postprocessing a transcription."""

    user_prompts = [text_input]
    system_message = process_instructions

    logger.debug(f"OpenAI Process Text with process instructions:\n{system_message}")
    if max_tokens == 0:
        tokens = token_count(text_input)
        max_tokens = tokens + TOKEN_BUFFER

    model_name = model or "default"

    logger.info(
        f"Open AI Text Processing{' as batch process' if batch else ''} "
        f"with model '{model_name}' initiated.\n"
        f"Requesting a maximum of {max_tokens} tokens."
    )

    if batch:
        return _run_batch_process_text(
            user_prompts, system_message, max_tokens, model_name, response_format
        )

    completion_result = simple_completion(
        system_message=system_message,
        user_message=text_input,
        model=model,
        max_tokens=max_tokens,
        response_model=response_format,
    )
    logger.info("Processing completed.")
    return completion_result
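The token budget above sizes `max_tokens` from the input length plus a fixed buffer when no explicit cap is given. A sketch of that sizing with a stand-in estimator (the real `token_count` is tokenizer-aware; the rough 4-characters-per-token heuristic here is only for illustration):

```python
TOKEN_BUFFER = 500

def resolve_max_tokens(text: str, max_tokens: int = 0) -> int:
    """Default the token cap to input size plus TOKEN_BUFFER when unset."""
    if max_tokens == 0:
        # Stand-in estimate: roughly 4 characters per token for English text.
        max_tokens = len(text) // 4 + TOKEN_BUFFER
    return max_tokens

print(resolve_max_tokens("word " * 100))  # 625
print(resolve_max_tokens("anything", 50))  # 50 (explicit cap wins)
```

The buffer leaves headroom for the output to run somewhat longer than the input, which is common for translation and cleanup tasks.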

prompts

MANAGER_UPDATE_MESSAGE = 'PromptManager Update:' module-attribute
MarkdownStr = NewType('MarkdownStr', str) module-attribute
logger = get_child_logger(__name__) module-attribute
ConcurrentAccessManager

Manages concurrent access to prompt files.

Provides:

  • File-level locking
  • Safe concurrent access to prompt files
  • Lock cleanup

Source code in src/tnh_scholar/ai_text_processing/prompts.py
class ConcurrentAccessManager:
    """
    Manages concurrent access to prompt files.

    Provides:
    - File-level locking
    - Safe concurrent access to prompt files
    - Lock cleanup
    """

    def __init__(self, lock_dir: Path):
        """
        Initialize access manager.

        Args:
            lock_dir: Directory for lock files
        """
        self.lock_dir = Path(lock_dir)
        self._ensure_lock_dir()
        self._cleanup_stale_locks()

    def _ensure_lock_dir(self) -> None:
        """Create lock directory if it doesn't exist."""
        self.lock_dir.mkdir(parents=True, exist_ok=True)

    def _cleanup_stale_locks(self, max_age: timedelta = timedelta(hours=1)) -> None:
        """
        Remove stale lock files.

        Args:
            max_age: Maximum age for lock files before considered stale
        """
        current_time = datetime.now()
        for lock_file in self.lock_dir.glob("*.lock"):
            try:
                mtime = datetime.fromtimestamp(lock_file.stat().st_mtime)
                if current_time - mtime > max_age:
                    lock_file.unlink()
                    logger.warning(f"Removed stale lock file: {lock_file}")
            except FileNotFoundError:
                # Lock was removed by another process
                pass
            except Exception as e:
                logger.error(f"Error cleaning up lock file {lock_file}: {e}")

    @contextmanager
    def file_lock(self, file_path: Path):
        """
        Context manager for safely accessing files.

        Args:
            file_path: Path to file to lock

        Yields:
            None when lock is acquired

        Raises:
            RuntimeError: If file is already locked
            OSError: If lock file operations fail
        """
        file_path = Path(file_path)
        lock_file_path = self.lock_dir / f"{file_path.stem}.lock"
        lock_fd = None

        try:
            # Open or create lock file
            lock_fd = os.open(str(lock_file_path), os.O_WRONLY | os.O_CREAT)

            try:
                # Attempt to acquire lock
                fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)

                # Write process info to lock file
                pid = os.getpid()
                timestamp = datetime.now().isoformat()
                os.write(lock_fd, f"{pid} {timestamp}\n".encode())

                logger.debug(f"Acquired lock for {file_path}")
                yield

            except BlockingIOError as e:
                raise RuntimeError(
                    f"File {file_path} is locked by another process"
                ) from e

        except OSError as e:
            logger.error(f"Lock operation failed for {file_path}: {e}")
            raise

        finally:
            if lock_fd is not None:
                try:
                    # Release lock and close file descriptor
                    fcntl.flock(lock_fd, fcntl.LOCK_UN)
                    os.close(lock_fd)

                    # Remove lock file
                    lock_file_path.unlink(missing_ok=True)
                    logger.debug(f"Released lock for {file_path}")

                except Exception as e:
                    logger.error(f"Error cleaning up lock for {file_path}: {e}")

    def is_locked(self, file_path: Path) -> bool:
        """
        Check if a file is currently locked.

        Args:
            file_path: Path to file to check

        Returns:
            bool: True if file is locked
        """
        lock_file_path = self.lock_dir / f"{file_path.stem}.lock"

        if not lock_file_path.exists():
            return False

        try:
            with open(lock_file_path, "r") as f:
                # Try to acquire and immediately release lock
                fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
                fcntl.flock(f, fcntl.LOCK_UN)
                return False
        except BlockingIOError:
            return True
        except Exception:
            return False
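The stale-lock sweep in `_cleanup_stale_locks` keys off file mtime age. The same check in isolation, exercised on a temporary directory (the `is_stale` helper name is illustrative):

```python
import os
import tempfile
from datetime import datetime, timedelta
from pathlib import Path

def is_stale(lock_file: Path, max_age: timedelta = timedelta(hours=1)) -> bool:
    """True if the lock file's mtime is older than max_age."""
    mtime = datetime.fromtimestamp(lock_file.stat().st_mtime)
    return datetime.now() - mtime > max_age

with tempfile.TemporaryDirectory() as d:
    lock = Path(d) / "demo.lock"
    lock.touch()
    print(is_stale(lock))  # False (just created)
    # Backdate the mtime by two hours to simulate a stale lock.
    past = datetime.now().timestamp() - 7200
    os.utime(lock, (past, past))
    print(is_stale(lock))  # True
```

Sweeping on startup this way clears locks left behind by crashed processes without needing a daemon.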
lock_dir = Path(lock_dir) instance-attribute
__init__(lock_dir)

Initialize access manager.

Parameters:

    lock_dir (Path): Directory for lock files. Required.
Source code in src/tnh_scholar/ai_text_processing/prompts.py
def __init__(self, lock_dir: Path):
    """
    Initialize access manager.

    Args:
        lock_dir: Directory for lock files
    """
    self.lock_dir = Path(lock_dir)
    self._ensure_lock_dir()
    self._cleanup_stale_locks()
file_lock(file_path)

Context manager for safely accessing files.

Parameters:

    file_path (Path): Path to file to lock. Required.

Yields:

    None when lock is acquired.

Raises:

    RuntimeError: If file is already locked.
    OSError: If lock file operations fail.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
@contextmanager
def file_lock(self, file_path: Path):
    """
    Context manager for safely accessing files.

    Args:
        file_path: Path to file to lock

    Yields:
        None when lock is acquired

    Raises:
        RuntimeError: If file is already locked
        OSError: If lock file operations fail
    """
    file_path = Path(file_path)
    lock_file_path = self.lock_dir / f"{file_path.stem}.lock"
    lock_fd = None

    try:
        # Open or create lock file
        lock_fd = os.open(str(lock_file_path), os.O_WRONLY | os.O_CREAT)

        try:
            # Attempt to acquire lock
            fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)

            # Write process info to lock file
            pid = os.getpid()
            timestamp = datetime.now().isoformat()
            os.write(lock_fd, f"{pid} {timestamp}\n".encode())

            logger.debug(f"Acquired lock for {file_path}")
            yield

        except BlockingIOError as e:
            raise RuntimeError(
                f"File {file_path} is locked by another process"
            ) from e

    except OSError as e:
        logger.error(f"Lock operation failed for {file_path}: {e}")
        raise

    finally:
        if lock_fd is not None:
            try:
                # Release lock and close file descriptor
                fcntl.flock(lock_fd, fcntl.LOCK_UN)
                os.close(lock_fd)

                # Remove lock file
                lock_file_path.unlink(missing_ok=True)
                logger.debug(f"Released lock for {file_path}")

            except Exception as e:
                logger.error(f"Error cleaning up lock for {file_path}: {e}")
is_locked(file_path)

Check if a file is currently locked.

Parameters:

    file_path (Path): Path to file to check. Required.

Returns:

    bool: True if file is locked.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def is_locked(self, file_path: Path) -> bool:
    """
    Check if a file is currently locked.

    Args:
        file_path: Path to file to check

    Returns:
        bool: True if file is locked
    """
    lock_file_path = self.lock_dir / f"{file_path.stem}.lock"

    if not lock_file_path.exists():
        return False

    try:
        with open(lock_file_path, "r") as f:
            # Try to acquire and immediately release lock
            fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
            fcntl.flock(f, fcntl.LOCK_UN)
            return False
    except BlockingIOError:
        return True
    except Exception:
        return False
GitBackedRepository

Manages versioned storage of prompts using Git.

Provides basic Git operations while hiding complexity:

  • Automatic versioning of changes
  • Basic conflict resolution
  • History tracking

Source code in src/tnh_scholar/ai_text_processing/prompts.py
class GitBackedRepository:
    """
    Manages versioned storage of prompts using Git.

    Provides basic Git operations while hiding complexity:
    - Automatic versioning of changes
    - Basic conflict resolution
    - History tracking
    """

    def __init__(self, repo_path: Path):
        """
        Initialize or connect to Git repository.

        Args:
            repo_path: Path to repository directory

        Raises:
            GitCommandError: If Git operations fail
        """
        self.repo_path = repo_path

        try:
            # Try to connect to existing repository
            self.repo = Repo(repo_path)
            logger.debug(f"Connected to existing Git repository at {repo_path}")

        except InvalidGitRepositoryError:
            # Initialize new repository if none exists
            logger.info(f"Initializing new Git repository at {repo_path}")
            self.repo = Repo.init(repo_path)

            # Create initial commit if repo is empty
            if not self.repo.head.is_valid():
                # Create and commit .gitignore
                gitignore = repo_path / ".gitignore"
                gitignore.write_text("*.lock\n.DS_Store\n")
                self.repo.index.add([".gitignore"])
                self.repo.index.commit("Initial repository setup")

    def update_file(self, file_path: Path) -> str:
        """
        Stage and commit changes to a file in the Git repository.

        Args:
            file_path: Absolute or relative path to the file.

        Returns:
            str: Commit hash if changes were made.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file is outside the repository.
            GitCommandError: If Git operations fail.
        """
        file_path = file_path.resolve()

        # Ensure the file is within the repository
        try:
            rel_path = file_path.relative_to(self.repo_path)
        except ValueError as e:
            raise ValueError(
                f"File {file_path} is not under the repository root {self.repo_path}"
            ) from e

        if not file_path.exists():
            raise FileNotFoundError(f"File does not exist: {file_path}")

        try:
            return self._commit_file_update(rel_path, file_path)
        except GitCommandError as e:
            logger.error(f"Git operation failed: {e}")
            raise

    def _commit_file_update(self, rel_path, file_path):
        if self._is_file_clean(rel_path):
            # Return the current commit hash if no changes
            return self.repo.head.commit.hexsha

        logger.info(f"Detected changes in {rel_path}, updating version control.")
        self.repo.index.add([str(rel_path)])
        commit = self.repo.index.commit(
            f"{MANAGER_UPDATE_MESSAGE} {rel_path.stem}",
            author=Actor("PromptManager", ""),
        )
        logger.info(f"Committed changes to {file_path}: {commit.hexsha}")
        return commit.hexsha

    def _get_file_revisions(self, file_path: Path) -> List[Commit]:
        """
        Get ordered list of commits that modified a file, most recent first.

        Args:
            file_path: Path to file relative to repository root

        Returns:
            List of Commit objects affecting this file

        Raises:
            GitCommandError: If Git operations fail
        """
        rel_path = file_path.relative_to(self.repo_path)
        try:
            return list(self.repo.iter_commits(paths=str(rel_path)))
        except GitCommandError as e:
            logger.error(f"Failed to get commits for {rel_path}: {e}")
            return []

    def _get_commit_diff(
        self, commit: Commit, file_path: Path, prev_commit: Optional[Commit] = None
    ) -> Tuple[str, str]:
        """
        Get both stat and detailed diff for a commit.

        Args:
            commit: Commit to diff
            file_path: Path relative to repository root
            prev_commit: Previous commit for diff, defaults to commit's parent

        Returns:
            Tuple of (stat_diff, detailed_diff) where:
                stat_diff: Summary of changes (files changed, insertions/deletions)
                detailed_diff: Colored word-level diff with context

        Raises:
            GitCommandError: If Git operations fail
        """
        prev_hash = prev_commit.hexsha if prev_commit else f"{commit.hexsha}^"
        rel_path = file_path.relative_to(self.repo_path)

        try:
            # Get stats diff
            stat = self.repo.git.diff(prev_hash, commit.hexsha, rel_path, stat=True)

            # Get detailed diff
            diff = self.repo.git.diff(
                prev_hash,
                commit.hexsha,
                rel_path,
                unified=2,
                word_diff="plain",
                color="always",
                ignore_space_change=True,
            )

            return stat, diff
        except GitCommandError as e:
            logger.error(f"Failed to get diff for {commit.hexsha}: {e}")
            return "", ""

    def display_history(self, file_path: Path, max_versions: int = 0) -> None:
        """
        Display history of changes for a file with diffs between versions.

        Shows most recent changes first, limited to max_versions entries.
        For each change shows:
        - Commit info and date
        - Stats summary of changes
        - Detailed color diff with 2 lines of context

        Args:
            file_path: Path to file in repository
            max_versions: Maximum number of versions to show;
                if zero, shows all revisions.

        Example:
            >>> repo.display_history(Path("prompts/format_dharma_talk.yaml"))
            Commit abc123def (2024-12-28 14:30:22):
            1 file changed, 5 insertions(+), 2 deletions(-)

            diff --git a/prompts/format_dharma_talk.yaml ...
            ...
        """

        try:
            # Get commit history
            commits = self._get_file_revisions(file_path)
            if not commits:
                print(f"No history found for {file_path}")
                return

            if max_versions == 0:
                max_versions = len(commits)  # look at all commits.

            # Display limited history with diffs
            for i, commit in enumerate(commits[:max_versions]):
                # Print commit header
                date_str = commit.committed_datetime.strftime("%Y-%m-%d %H:%M:%S")
                print(f"\nCommit {commit.hexsha[:8]} ({date_str}):")
                print(f"Message: {commit.message.strip()}")

                # Get and display diffs
                prev_commit = commits[i + 1] if i + 1 < len(commits) else None
                stat_diff, detailed_diff = self._get_commit_diff(
                    commit, file_path, prev_commit
                )

                if stat_diff:
                    print("\nChanges:")
                    print(stat_diff)
                if detailed_diff:
                    print("\nDetailed diff:")
                    print(detailed_diff)

                print("\033[0m", end="")
                print("-" * 80)  # Visual separator between commits

        except Exception as e:
            logger.error(f"Failed to display history for {file_path}: {e}")
            print(f"Error displaying history: {e}")
            raise

    def _is_file_clean(self, rel_path: Path) -> bool:
        """
        Check if file has uncommitted changes.

        Args:
            rel_path: Path relative to repository root

        Returns:
            bool: True if file has no changes
        """
        return str(rel_path) not in (
            [item.a_path for item in self.repo.index.diff(None)]
            + self.repo.untracked_files
        )
repo = Repo(repo_path) instance-attribute
repo_path = repo_path instance-attribute
__init__(repo_path)

Initialize or connect to Git repository.

Parameters:

    repo_path (Path): Path to repository directory. (required)

Raises:

    GitCommandError: If Git operations fail.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def __init__(self, repo_path: Path):
    """
    Initialize or connect to Git repository.

    Args:
        repo_path: Path to repository directory

    Raises:
        GitCommandError: If Git operations fail
    """
    self.repo_path = repo_path

    try:
        # Try to connect to existing repository
        self.repo = Repo(repo_path)
        logger.debug(f"Connected to existing Git repository at {repo_path}")

    except InvalidGitRepositoryError:
        # Initialize new repository if none exists
        logger.info(f"Initializing new Git repository at {repo_path}")
        self.repo = Repo.init(repo_path)

        # Create initial commit if repo is empty
        if not self.repo.head.is_valid():
            # Create and commit .gitignore
            gitignore = repo_path / ".gitignore"
            gitignore.write_text("*.lock\n.DS_Store\n")
            self.repo.index.add([".gitignore"])
            self.repo.index.commit("Initial repository setup")
display_history(file_path, max_versions=0)

Display history of changes for a file with diffs between versions.

Shows most recent changes first, limited to max_versions entries. For each change shows:

- Commit info and date
- Stats summary of changes
- Detailed color diff with 2 lines of context

Parameters:

    file_path (Path): Path to file in repository. (required)
    max_versions (int): Maximum number of versions to show; if zero, shows all revisions. (default: 0)

Example:

    >>> repo.display_history(Path("prompts/format_dharma_talk.yaml"))
    Commit abc123def (2024-12-28 14:30:22):
    1 file changed, 5 insertions(+), 2 deletions(-)

    diff --git a/prompts/format_dharma_talk.yaml ...
    ...

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def display_history(self, file_path: Path, max_versions: int = 0) -> None:
    """
    Display history of changes for a file with diffs between versions.

    Shows most recent changes first, limited to max_versions entries.
    For each change shows:
    - Commit info and date
    - Stats summary of changes
    - Detailed color diff with 2 lines of context

    Args:
        file_path: Path to file in repository
        max_versions: Maximum number of versions to show;
            if zero, shows all revisions.

    Example:
        >>> repo.display_history(Path("prompts/format_dharma_talk.yaml"))
        Commit abc123def (2024-12-28 14:30:22):
        1 file changed, 5 insertions(+), 2 deletions(-)

        diff --git a/prompts/format_dharma_talk.yaml ...
        ...
    """

    try:
        # Get commit history
        commits = self._get_file_revisions(file_path)
        if not commits:
            print(f"No history found for {file_path}")
            return

        if max_versions == 0:
            max_versions = len(commits)  # look at all commits.

        # Display limited history with diffs
        for i, commit in enumerate(commits[:max_versions]):
            # Print commit header
            date_str = commit.committed_datetime.strftime("%Y-%m-%d %H:%M:%S")
            print(f"\nCommit {commit.hexsha[:8]} ({date_str}):")
            print(f"Message: {commit.message.strip()}")

            # Get and display diffs
            prev_commit = commits[i + 1] if i + 1 < len(commits) else None
            stat_diff, detailed_diff = self._get_commit_diff(
                commit, file_path, prev_commit
            )

            if stat_diff:
                print("\nChanges:")
                print(stat_diff)
            if detailed_diff:
                print("\nDetailed diff:")
                print(detailed_diff)

            print("\033[0m", end="")
            print("-" * 80)  # Visual separator between commits

    except Exception as e:
        logger.error(f"Failed to display history for {file_path}: {e}")
        print(f"Error displaying history: {e}")
        raise
update_file(file_path)

Stage and commit changes to a file in the Git repository.

Parameters:

    file_path (Path): Absolute or relative path to the file. (required)

Returns:

    str: Commit hash if changes were made.

Raises:

    FileNotFoundError: If the file does not exist.
    ValueError: If the file is outside the repository.
    GitCommandError: If Git operations fail.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def update_file(self, file_path: Path) -> str:
    """
    Stage and commit changes to a file in the Git repository.

    Args:
        file_path: Absolute or relative path to the file.

    Returns:
        str: Commit hash if changes were made.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the file is outside the repository.
        GitCommandError: If Git operations fail.
    """
    file_path = file_path.resolve()

    # Ensure the file is within the repository
    try:
        rel_path = file_path.relative_to(self.repo_path)
    except ValueError as e:
        raise ValueError(
            f"File {file_path} is not under the repository root {self.repo_path}"
        ) from e

    if not file_path.exists():
        raise FileNotFoundError(f"File does not exist: {file_path}")

    try:
        return self._commit_file_update(rel_path, file_path)
    except GitCommandError as e:
        logger.error(f"Git operation failed: {e}")
        raise
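The containment check in `update_file` comes down to `Path.resolve()` plus `Path.relative_to()`, which raises `ValueError` when the path is not under the given root. A minimal stdlib-only sketch of that pattern (the `/repo` paths below are hypothetical examples, not package paths):

```python
from pathlib import Path


def relative_to_repo(file_path: Path, repo_path: Path) -> Path:
    """Return file_path relative to repo_path, or raise ValueError."""
    file_path = file_path.resolve()
    try:
        # relative_to raises ValueError when file_path is outside repo_path
        return file_path.relative_to(repo_path)
    except ValueError as e:
        raise ValueError(
            f"File {file_path} is not under the repository root {repo_path}"
        ) from e


rel = relative_to_repo(Path("/repo/prompts/talk.md"), Path("/repo"))
print(rel)  # prompts/talk.md
```

Because the check happens before any Git operation, a misplaced file fails fast with a clear error instead of a confusing `git add` failure.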
LocalPromptManager

A simple singleton implementation of PromptManager that ensures only one instance is created and reused throughout the application lifecycle.

This class wraps the PromptManager to provide efficient prompt loading by maintaining a single reusable instance.

Attributes:

    _instance (Optional[LocalPromptManager]): The singleton instance
    _prompt_manager (Optional[PromptManager]): The wrapped PromptManager instance

Source code in src/tnh_scholar/ai_text_processing/prompts.py
class LocalPromptManager:
    """
    A simple singleton implementation of PromptManager that ensures only one instance
    is created and reused throughout the application lifecycle.

    This class wraps the PromptManager to provide efficient prompt loading by
    maintaining a single reusable instance.

    Attributes:
        _instance (Optional[LocalPromptManager]): The singleton instance
        _prompt_manager (Optional[PromptManager]): The wrapped PromptManager instance
    """

    _instance: Optional["LocalPromptManager"] = None

    def __new__(cls) -> "LocalPromptManager":
        """
        Create or return the singleton instance.

        Returns:
            LocalPromptManager: The singleton instance
        """
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._prompt_manager = None
        return cls._instance

    @property
    def prompt_manager(self) -> "PromptCatalog":
        """
        Lazy initialization of the PromptManager instance.

        Returns:
            PromptManager: The wrapped PromptManager instance

        Raises:
            RuntimeError: If PATTERN_REPO is not properly configured
        """
        if self._prompt_manager is None:  # type: ignore
            try:
                load_dotenv()
                if prompt_path_name := os.getenv("TNH_PATTERN_DIR"):
                    prompt_dir = Path(prompt_path_name)
                    logger.debug(f"prompt dir: {prompt_path_name}")
                else:
                    prompt_dir = TNH_DEFAULT_PATTERN_DIR
                self._prompt_manager = PromptCatalog(prompt_dir)
            except ImportError as err:
                raise RuntimeError(
                    "Failed to initialize PromptManager. Ensure prompt_manager "
                    f"module and PATTERN_REPO are properly configured: {err}"
                ) from err
        return self._prompt_manager

    def get_prompt(self, name: str) -> Prompt:
        """Get a prompt by name."""
        return self.prompt_manager.load(Prompt._normalize_name(name))
prompt_manager property

Lazy initialization of the PromptManager instance.

Returns:

    PromptCatalog: The wrapped PromptManager instance

Raises:

    RuntimeError: If PATTERN_REPO is not properly configured

__new__()

Create or return the singleton instance.

Returns:

    LocalPromptManager: The singleton instance

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def __new__(cls) -> "LocalPromptManager":
    """
    Create or return the singleton instance.

    Returns:
        LocalPromptManager: The singleton instance
    """
    if cls._instance is None:
        cls._instance = super().__new__(cls)
        cls._instance._prompt_manager = None
    return cls._instance
get_prompt(name)

Get a prompt by name.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def get_prompt(self, name: str) -> Prompt:
    """Get a prompt by name."""
    return self.prompt_manager.load(Prompt._normalize_name(name))
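The singleton behavior of `LocalPromptManager` comes from overriding `__new__`. A stripped-down, stdlib-only sketch of that pattern (the `Singleton` class name and `_resource` attribute are illustrative, not part of the package):

```python
from typing import Optional


class Singleton:
    """Create the instance once; return the same object thereafter."""

    _instance: Optional["Singleton"] = None

    def __new__(cls) -> "Singleton":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Heavy state is left unset here and initialized lazily on
            # first access, mirroring the prompt_manager property above.
            cls._instance._resource = None
        return cls._instance


a, b = Singleton(), Singleton()
print(a is b)  # True
```

Every call site gets the same object, so the wrapped prompt catalog is built at most once per process.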
Prompt

Base Prompt class for version-controlled template prompts.

Prompts contain:

- Instructions: The main prompt instructions as a Jinja2 template. Note: Instructions are intended to be saved in markdown format in a .md file.
- Template fields: Default values for template variables
- Metadata: Name and identifier information

Version control is handled externally through Git, not in the prompt itself. Prompt identity is determined by the combination of identifiers.

Attributes:

    name (str): The name of the prompt
    instructions (str): The Jinja2 template string for this prompt
    default_template_fields (Dict[str, str]): Default values for template variables
    _allow_empty_vars (bool): Whether to allow undefined template variables
    _env (Environment): Configured Jinja2 environment instance

Source code in src/tnh_scholar/ai_text_processing/prompts.py
class Prompt:
    """
    Base Prompt class for version-controlled template prompts.

    Prompts contain:
    - Instructions: The main prompt instructions as a Jinja2 template.
       Note: Instructions are intended to be saved in markdown format in a .md file.
    - Template fields: Default values for template variables
    - Metadata: Name and identifier information

    Version control is handled externally through Git, not in the prompt itself.
    Prompt identity is determined by the combination of identifiers.

    Attributes:
        name (str): The name of the prompt
        instructions (str): The Jinja2 template string for this prompt
        default_template_fields (Dict[str, str]): Default values for template variables
        _allow_empty_vars (bool): Whether to allow undefined template variables
        _env (Environment): Configured Jinja2 environment instance
    """

    @staticmethod
    def _normalize_name(value: str) -> str:
        """Canonicalize prompt names for case-insensitive handling.

        Currently: strip() + lower(). If future rules are needed (e.g.,
        removing punctuation, limiting length), implement them here.
        """
        return value.strip().lower()

    def __init__(
        self,
        name: str,
        instructions: MarkdownStr,
        path: Optional[Path] = None,
        default_template_fields: Optional[Dict[str, str]] = None,
        allow_empty_vars: bool = False,        
    ) -> None:
        """
        Initialize a new Prompt instance.

        Args:
            name: Unique name identifying the prompt
            instructions: Jinja2 template string containing the prompt
            default_template_fields: Optional default values for template variables
            allow_empty_vars: Whether to allow undefined template variables

        Raises:
            ValueError: If name or instructions are empty
            TemplateError: If template syntax is invalid
        """
        if not name or not instructions:
            raise ValueError("Name and instructions must not be empty")

        # Normalize prompt name to lowercase for case-insensitive handling
        name = Prompt._normalize_name(name)

        self.name = name
        self.instructions = instructions
        self.path = path
        self.default_template_fields = default_template_fields or {}
        self._allow_empty_vars = allow_empty_vars
        self._env = self._create_environment()

        # Validate template syntax on initialization
        self._validate_template()

    @staticmethod
    def _create_environment() -> Environment:
        """
        Create and configure a Jinja2 environment with optimal settings.

        Returns:
            Environment: Configured Jinja2 environment 
            with security and formatting options
        """
        return Environment(
            undefined=StrictUndefined,  # Raise errors for undefined variables
            trim_blocks=True,  # Remove first newline after a block
            lstrip_blocks=True,  # Strip tabs and spaces from the start of lines
            autoescape=True,  # Enable autoescaping for security
        )

    def _validate_template(self) -> None:
        """
        Validate the template syntax without rendering.

        Raises:
            TemplateError: If template syntax is invalid
        """
        try:
            self._env.parse(self.instructions)
        except TemplateError as e:
            raise TemplateError(
                f"Invalid template syntax in prompt '{self.name}': {str(e)}"
            ) from e

    def apply_template(self, field_values: Optional[Dict[str, str]] = None) -> str:
        """
        Apply template values to prompt instructions using Jinja2.

        Values precedence (highest to lowest):
        1. field_values (explicitly passed)
        2. frontmatter values (from prompt file)
        3. default_template_fields (prompt defaults)

        Args:
            field_values: Values to substitute into the template.
                        If None, uses frontmatter/defaults.

        Returns:
            str: Rendered instructions with template values applied.

        Raises:
            TemplateError: If template rendering fails
            ValueError: If required template variables are missing
        """
        # Get frontmatter values
        frontmatter = self.extract_frontmatter() or {}

        # Combine values with correct precedence using | operator
        template_values = self.default_template_fields | \
            frontmatter | (field_values or {})

        instructions = self.get_content_without_frontmatter()
        logger.debug(f"instructions without frontmatter:\n{instructions}")

        try:
            return self._render_template_with_values(instructions, template_values)
        except TemplateError as e:
            raise TemplateError(
                f"Template rendering failed for prompt '{self.name}': {str(e)}"
                ) from e

    def _render_template_with_values(
        self, 
        instructions: str, 
        template_values: dict
        ) -> str:
        """
        Validate and render template with provided values.

        Args:
            instructions: Template content without frontmatter
            template_values: Values to substitute into template

        Returns:
            Rendered template string

        Raises:
            ValueError: If required template variables are missing
        """
        # Parse for validation
        parsed_content = self._env.parse(instructions)
        required_vars = find_undeclared_variables(parsed_content)

        # Validate variables
        missing_vars = required_vars - set(template_values.keys())
        if missing_vars and not self._allow_empty_vars:
            raise ValueError(
                f"Missing required template variables in prompt '{self.name}': "
                f"{', '.join(sorted(missing_vars))}"
            )

        # Create and render template
        template = self._env.from_string(instructions)
        return template.render(**template_values)

    def extract_frontmatter(self) -> Optional[Dict[str, Any]]:
        """
        Extract and validate YAML frontmatter from markdown instructions.

        Returns:
            Optional[Dict]: Frontmatter data if found and valid, None otherwise

        Note:
            Frontmatter must be at the very start of the file and properly formatted.
        """

        prompt = r"\A---\s*\n(.*?)\n---\s*(?:\n|$)"
        if match := re.match(prompt, self.instructions, re.DOTALL):
            try:
                frontmatter = yaml.safe_load(match[1])
                if frontmatter is None:
                    return None
                if not isinstance(frontmatter, dict):
                    logger.warning(f"Frontmatter must be a YAML dictionary: "
                                   f"{frontmatter}")
                    return None
                return frontmatter
            except yaml.YAMLError as e:
                logger.warning(f"Invalid YAML in frontmatter: {e}")
                return None
        return None

    def get_content_without_frontmatter(self) -> str:
        """
        Get markdown content with frontmatter removed.

        Returns:
            str: Markdown content without frontmatter
        """
        prompt = r"\A---\s*\n.*?\n---\s*\n"
        return re.sub(prompt, "", self.instructions, flags=re.DOTALL)

    def update_frontmatter(self, new_data: Dict[str, Any]) -> None:
        """
        Update or add frontmatter to the markdown content.

        Args:
            new_data: Dictionary of frontmatter fields to update
        """

        current_frontmatter = self.extract_frontmatter() or {}
        updated_frontmatter = {**current_frontmatter, **new_data}

        # Create YAML string
        yaml_str = yaml.dump(
            updated_frontmatter, default_flow_style=False, allow_unicode=True
        )

        # Remove existing frontmatter if present
        content = self.get_content_without_frontmatter()

        # Combine new frontmatter with content
        self.instructions = f"---\n{yaml_str}---\n\n{content}"


    def source_bytes(self) -> bytes:
        """
        Best-effort raw bytes for prompt hashing.

        Prefers hashing the exact on-disk bytes, including front-matter,
        so we first try to read from `self.path`. If that is unavailable, we fall
        back to hashing the concatenation of known templates. In V1, only
        the instructions (system template) are used for rendering.
        """
        # Preferred path: use on-disk bytes when available.
        if self.path is not None:
            return self.path.read_bytes()

        # Fallback: concatenate known templates deterministically
        sys_part = self.instructions or ""
        return sys_part.encode("utf-8")

    def content_hash(self) -> str:
        """
        Generate a SHA-256 hash of the prompt content.

        Useful for quick content comparison and change detection.

        Returns:
            str: Hexadecimal string of the SHA-256 hash
        """
        content = (
            f"{self.name}{self.instructions}"
            f"{sorted(self.default_template_fields.items())}"
            )
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert prompt to dictionary for serialization.

        Returns:
            Dict containing all prompt data in serializable format
        """
        return {
            "name": self.name,
            "instructions": self.instructions,
            "default_template_fields": self.default_template_fields,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Prompt":
        """
        Create prompt instance from dictionary data.

        Args:
            data: Dictionary containing prompt data

        Returns:
            Prompt: New prompt instance

        Raises:
            ValueError: If required fields are missing
        """
        required_fields = {"name", "instructions"}
        if missing_fields := required_fields - set(data.keys()):
            raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")

        return cls(
            name=Prompt._normalize_name(str(data["name"])),
            instructions=data["instructions"],
            path=None,
            default_template_fields=data.get("default_template_fields", {}),
        )

    def __eq__(self, other: object) -> bool:
        """Compare prompts based on their content."""
        if not isinstance(other, Prompt):
            return NotImplemented
        return self.content_hash() == other.content_hash()

    def __hash__(self) -> int:
        """Hash based on content hash for container operations."""
        return hash(self.content_hash())
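The frontmatter handling above is regex-based. This stdlib-only sketch applies the same two patterns used by `extract_frontmatter` and `get_content_without_frontmatter` to a sample document (the YAML parsing step with `yaml.safe_load` is omitted; the sample field names are made up):

```python
import re

doc = """---
name: format_dharma_talk
language: en
---

Format the following talk transcript.
"""

# Same pattern as extract_frontmatter: frontmatter must start the file (\A).
fm_match = re.match(r"\A---\s*\n(.*?)\n---\s*(?:\n|$)", doc, re.DOTALL)
frontmatter_block = fm_match[1] if fm_match else None

# Same pattern as get_content_without_frontmatter: strip the whole block.
body = re.sub(r"\A---\s*\n.*?\n---\s*\n", "", doc, flags=re.DOTALL)

print(frontmatter_block)  # name: format_dharma_talk\nlanguage: en
print(body.strip())       # Format the following talk transcript.
```

Anchoring on `\A` is what enforces the "frontmatter must be at the very start of the file" rule: a `---` fence appearing later in the document is left alone.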
default_template_fields = default_template_fields or {} instance-attribute
instructions = instructions instance-attribute
name = name instance-attribute
path = path instance-attribute
__eq__(other)

Compare prompts based on their content.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def __eq__(self, other: object) -> bool:
    """Compare prompts based on their content."""
    if not isinstance(other, Prompt):
        return NotImplemented
    return self.content_hash() == other.content_hash()
__hash__()

Hash based on content hash for container operations.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def __hash__(self) -> int:
    """Hash based on content hash for container operations."""
    return hash(self.content_hash())
__init__(name, instructions, path=None, default_template_fields=None, allow_empty_vars=False)

Initialize a new Prompt instance.

Parameters:

    name (str): Unique name identifying the prompt. (required)
    instructions (MarkdownStr): Jinja2 template string containing the prompt. (required)
    default_template_fields (Optional[Dict[str, str]]): Optional default values for template variables. (default: None)
    allow_empty_vars (bool): Whether to allow undefined template variables. (default: False)

Raises:

    ValueError: If name or instructions are empty
    TemplateError: If template syntax is invalid

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def __init__(
    self,
    name: str,
    instructions: MarkdownStr,
    path: Optional[Path] = None,
    default_template_fields: Optional[Dict[str, str]] = None,
    allow_empty_vars: bool = False,        
) -> None:
    """
    Initialize a new Prompt instance.

    Args:
        name: Unique name identifying the prompt
        instructions: Jinja2 template string containing the prompt
        default_template_fields: Optional default values for template variables
        allow_empty_vars: Whether to allow undefined template variables

    Raises:
        ValueError: If name or instructions are empty
        TemplateError: If template syntax is invalid
    """
    if not name or not instructions:
        raise ValueError("Name and instructions must not be empty")

    # Normalize prompt name to lowercase for case-insensitive handling
    name = Prompt._normalize_name(name)

    self.name = name
    self.instructions = instructions
    self.path = path
    self.default_template_fields = default_template_fields or {}
    self._allow_empty_vars = allow_empty_vars
    self._env = self._create_environment()

    # Validate template syntax on initialization
    self._validate_template()
apply_template(field_values=None)

Apply template values to prompt instructions using Jinja2.

Values precedence (highest to lowest):

1. field_values (explicitly passed)
2. frontmatter values (from prompt file)
3. default_template_fields (prompt defaults)

Parameters:

    field_values (Optional[Dict[str, str]]): Values to substitute into the template. If None, uses frontmatter/defaults. (default: None)

Returns:

    str: Rendered instructions with template values applied.

Raises:

    TemplateError: If template rendering fails
    ValueError: If required template variables are missing

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def apply_template(self, field_values: Optional[Dict[str, str]] = None) -> str:
    """
    Apply template values to prompt instructions using Jinja2.

    Values precedence (highest to lowest):
    1. field_values (explicitly passed)
    2. frontmatter values (from prompt file)
    3. default_template_fields (prompt defaults)

    Args:
        field_values: Values to substitute into the template.
                    If None, uses frontmatter/defaults.

    Returns:
        str: Rendered instructions with template values applied.

    Raises:
        TemplateError: If template rendering fails
        ValueError: If required template variables are missing
    """
    # Get frontmatter values
    frontmatter = self.extract_frontmatter() or {}

    # Combine values with correct precedence using | operator
    template_values = self.default_template_fields | \
        frontmatter | (field_values or {})

    instructions = self.get_content_without_frontmatter()
    logger.debug(f"instructions without frontmatter:\n{instructions}")

    try:
        return self._render_template_with_values(instructions, template_values)
    except TemplateError as e:
        raise TemplateError(
            f"Template rendering failed for prompt '{self.name}': {str(e)}"
            ) from e
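The precedence rule above can be illustrated with plain dictionary merging. This is a simplified sketch, not the package API: the real implementation renders with Jinja2, while stdlib `string.Template` stands in here, and all names below are illustrative.

```python
from string import Template

# Simplified illustration of apply_template's precedence rules.
# Later operands of | win: defaults < frontmatter < explicit field_values.

def render_with_precedence(template_text, defaults, frontmatter, field_values=None):
    values = defaults | frontmatter | (field_values or {})
    return Template(template_text).substitute(values)

text = "Translate into $language with a $style tone."
defaults = {"language": "English", "style": "neutral"}
frontmatter = {"style": "formal"}
```

Here the frontmatter overrides the default style, and an explicitly passed `field_values` entry overrides both.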
content_hash()

Generate a SHA-256 hash of the prompt content.

Useful for quick content comparison and change detection.

Returns:

Name Type Description
str str

Hexadecimal string of the SHA-256 hash

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def content_hash(self) -> str:
    """
    Generate a SHA-256 hash of the prompt content.

    Useful for quick content comparison and change detection.

    Returns:
        str: Hexadecimal string of the SHA-256 hash
    """
    content = (
        f"{self.name}{self.instructions}"
        f"{sorted(self.default_template_fields.items())}"
        )
    return hashlib.sha256(content.encode("utf-8")).hexdigest()
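A minimal sketch of the hashing scheme shown above: sorting the default fields before hashing makes the digest independent of dictionary insertion order. The standalone function below is illustrative, not the package API.

```python
import hashlib

# Hash name + instructions + sorted default fields, as in content_hash above.
def content_hash(name, instructions, default_fields):
    content = f"{name}{instructions}{sorted(default_fields.items())}"
    return hashlib.sha256(content.encode("utf-8")).hexdigest()

a = content_hash("translate", "Translate the text.", {"b": "2", "a": "1"})
b = content_hash("translate", "Translate the text.", {"a": "1", "b": "2"})
```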
extract_frontmatter()

Extract and validate YAML frontmatter from markdown instructions.

Returns:

Type Description
Optional[Dict[str, Any]]

Optional[Dict]: Frontmatter data if found and valid, None otherwise

Note

Frontmatter must be at the very start of the file and properly formatted.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def extract_frontmatter(self) -> Optional[Dict[str, Any]]:
    """
    Extract and validate YAML frontmatter from markdown instructions.

    Returns:
        Optional[Dict]: Frontmatter data if found and valid, None otherwise

    Note:
        Frontmatter must be at the very start of the file and properly formatted.
    """

    prompt = r"\A---\s*\n(.*?)\n---\s*(?:\n|$)"
    if match := re.match(prompt, self.instructions, re.DOTALL):
        try:
            frontmatter = yaml.safe_load(match[1])
            if frontmatter is None:
                return None
            if not isinstance(frontmatter, dict):
                logger.warning(f"Frontmatter must be a YAML dictionary: "
                               f"{frontmatter}")
                return None
            return frontmatter
        except yaml.YAMLError as e:
            logger.warning(f"Invalid YAML in frontmatter: {e}")
            return None
    return None
from_dict(data) classmethod

Create prompt instance from dictionary data.

Parameters:

Name Type Description Default
data Dict[str, Any]

Dictionary containing prompt data

required

Returns:

Name Type Description
Prompt Prompt

New prompt instance

Raises:

Type Description
ValueError

If required fields are missing

Source code in src/tnh_scholar/ai_text_processing/prompts.py
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Prompt":
    """
    Create prompt instance from dictionary data.

    Args:
        data: Dictionary containing prompt data

    Returns:
        Prompt: New prompt instance

    Raises:
        ValueError: If required fields are missing
    """
    required_fields = {"name", "instructions"}
    if missing_fields := required_fields - set(data.keys()):
        raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")

    return cls(
        name=Prompt._normalize_name(str(data["name"])),
        instructions=data["instructions"],
        path=None,
        default_template_fields=data.get("default_template_fields", {}),
    )
get_content_without_frontmatter()

Get markdown content with frontmatter removed.

Returns:

Name Type Description
str str

Markdown content without frontmatter

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def get_content_without_frontmatter(self) -> str:
    """
    Get markdown content with frontmatter removed.

    Returns:
        str: Markdown content without frontmatter
    """
    prompt = r"\A---\s*\n.*?\n---\s*\n"
    return re.sub(prompt, "", self.instructions, flags=re.DOTALL)
source_bytes()

Best-effort raw bytes for prompt hashing.

Prefers hashing exact on-disk bytes including front-matter. We therefore first try to read from path. If no path is set, we fall back to hashing the concatenation of known templates. In V1, only the instructions (system template) are used for rendering.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def source_bytes(self) -> bytes:
    """
    Best-effort raw bytes for prompt hashing.

    Prefers hashing exact on-disk bytes including front-matter.
    We therefore first try to read from `path`. If no path is set, we fall back
    to hashing the concatenation of known templates. In V1, only
    the instructions (system template) are used for rendering.
    """
    # Preferred path: use on-disk bytes when available.
    if self.path is not None:
        return self.path.read_bytes()

    # Fallback: concatenate known templates deterministically
    sys_part = self.instructions or ""
    return sys_part.encode("utf-8")
to_dict()

Convert prompt to dictionary for serialization.

Returns:

Type Description
Dict[str, Any]

Dict containing all prompt data in serializable format

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def to_dict(self) -> Dict[str, Any]:
    """
    Convert prompt to dictionary for serialization.

    Returns:
        Dict containing all prompt data in serializable format
    """
    return {
        "name": self.name,
        "instructions": self.instructions,
        "default_template_fields": self.default_template_fields,
    }
update_frontmatter(new_data)

Update or add frontmatter to the markdown content.

Parameters:

Name Type Description Default
new_data Dict[str, Any]

Dictionary of frontmatter fields to update

required
Source code in src/tnh_scholar/ai_text_processing/prompts.py
Source code in src/tnh_scholar/ai_text_processing/prompts.py
def update_frontmatter(self, new_data: Dict[str, Any]) -> None:
    """
    Update or add frontmatter to the markdown content.

    Args:
        new_data: Dictionary of frontmatter fields to update
    """

    current_frontmatter = self.extract_frontmatter() or {}
    updated_frontmatter = {**current_frontmatter, **new_data}

    # Create YAML string
    yaml_str = yaml.dump(
        updated_frontmatter, default_flow_style=False, allow_unicode=True
    )

    # Remove existing frontmatter if present
    content = self.get_content_without_frontmatter()

    # Combine new frontmatter with content
    self.instructions = f"---\n{yaml_str}---\n\n{content}"
PromptCatalog

Main interface for prompt management system.

Provides high-level operations:

- Prompt creation and loading
- Automatic versioning
- Safe concurrent access
- Basic history tracking
- Case-insensitive prompt names (stored as lowercase)

Source code in src/tnh_scholar/ai_text_processing/prompts.py
class PromptCatalog:
    """
    Main interface for prompt management system.

    Provides high-level operations:
    - Prompt creation and loading
    - Automatic versioning
    - Safe concurrent access
    - Basic history tracking
    - Case-insensitive prompt names (stored as lowercase)
    """

    def __init__(self, base_path: Path):
        """
        Initialize prompt management system.

        Args:
            base_path: Base directory for prompt storage
        """
        self.base_path = Path(base_path).resolve()
        self.base_path.mkdir(parents=True, exist_ok=True)

        # Initialize subsystems
        self.repo = GitBackedRepository(self.base_path)
        self.access_manager = ConcurrentAccessManager(self.base_path / ".locks")

        logger.info(f"Initialized prompt management system at {base_path}")

    def _normalize_path(self, path: Union[str, Path]) -> Path:
        """
        Normalize a path to be absolute under the repository base path.

        Maps these cases to the same result:
        - "my_file" -> <base_path>/my_file
        - "<base_path>/my_file" -> <base_path>/my_file

        Args:
            path: Input path as string or Path

        Returns:
            Path: Absolute path under base_path

        Raises:
            ValueError: If path would resolve outside repository
        """
        path = Path(path)  # ensure we have a path

        # Join with base_path as needed: always interpret relative
        # paths as relative to the repository base path. This avoids
        # incorrectly handling nested relative paths like "a/b"
        # which may not have the same parent as self.base_path.
        if not path.is_absolute():
            path = self.base_path / path

        # Safety check after resolution
        resolved = path.resolve()
        try:
            resolved.relative_to(self.base_path)
        except ValueError as e:
            raise ValueError(
                f"Path {path} resolves outside repository: {self.base_path}"
            ) from e

        return resolved

    def get_path(self, prompt_name: str) -> Optional[Path]:
        """
        Recursively search for a prompt file with the given name (case-insensitive)
        in base_path and all subdirectories.

        Args:
            prompt_name: prompt name (without extension) to search for

        Returns:
            Optional[Path]: Full path to the found prompt file, or None if not found
        """
        target = Prompt._normalize_name(prompt_name)
        with suppress(StopIteration):
            for path in self.base_path.rglob("*.md"):
                if path.is_file() and path.stem.lower() == target:
                    logger.debug(
                        f"Found prompt file for name {prompt_name} at: {path}"
                    )
                    return self._normalize_path(path)
        logger.debug(f"No prompt file found with name: {prompt_name}")
        return None

    def save(self, prompt: Prompt, subdir: Optional[Path] = None) -> Path:
        prompt_name = Prompt._normalize_name(prompt.name)
        instructions = prompt.instructions

        if subdir is None:
            path = self.base_path / f"{prompt_name}.md"
        else:
            path = self.base_path / subdir / f"{prompt_name}.md"

        path = self._normalize_path(path)

        # Check for existing prompt by case-insensitive match
        existing_path = self.get_path(prompt_name)

        try:
            # Lock on the destination path name (lowercase) to avoid races
            with self.access_manager.file_lock(path):
                # If an existing file is present but at a different case/path, rename it
                if existing_path is not None and existing_path != path:
                    path.parent.mkdir(parents=True, exist_ok=True)
                    logger.info(
                        f"Renaming existing prompt file from {existing_path} to {path} "
                        "to enforce lowercase naming."
                    )
                    existing_path.rename(path)

                write_str_to_file(path, instructions, overwrite=True)
                self.repo.update_file(path)
                logger.info(f"Prompt saved at {path}")
                return path.relative_to(self.base_path)

        except Exception as e:
            logger.error(f"Failed to save prompt {prompt_name}: {e}")
            raise

    def load(self, prompt_name: str) -> Prompt:
        """
        Load the .md prompt file by name, extract placeholders, and
        return a fully constructed Prompt object.

        Args:
            prompt_name: Name of the prompt (without .md extension).

        Returns:
            A new Prompt object whose 'instructions' is the file's text
            and whose 'template_fields' are inferred from placeholders in
            those instructions.
        """
        prompt_name = Prompt._normalize_name(prompt_name)
        # Locate the .md file; raise if missing
        path = self.get_path(prompt_name)
        if not path:
            raise FileNotFoundError(f"No prompt file named {prompt_name}.md found in prompt catalog:\n"
                                    f"{self.base_path}"
                                    )

        # Acquire lock before reading
        with self.access_manager.file_lock(path):
            instructions = read_str_from_file(path)

        instructions = MarkdownStr(instructions)

        # Create the prompt from the raw .md text (name is already lowercase)
        prompt = Prompt(name=prompt_name, instructions=instructions, path=path)

        # Check for local uncommitted changes, updating file:
        self.repo.update_file(path)

        return prompt

    def show_history(self, prompt_name: str) -> None:
        if path := self.get_path(prompt_name):
            self.repo.display_history(path)
        else:
            logger.error(f"Path to {prompt_name} not found.")
            return

    # def get_prompt_history_from_path(self, path: Path) -> List[Dict[str, Any]]:
    #     """
    #     Get version history for a prompt.

    #     Args:
    #         path: Path to prompt file

    #     Returns:
    #         List of version information
    #     """
    #     path = self._normalize_path(path)

    #     return self.repo.get_history(path)

    @classmethod
    def verify_repository(cls, base_path: Path) -> bool:
        """
        Verify repository integrity and uniqueness of prompt names.

        Performs the following checks:
        1. Validates Git repository structure.
        2. Ensures no duplicate prompt names exist.

        Args:
            base_path: Repository path to verify.

        Returns:
            bool: True if the repository is valid 
            and contains no duplicate prompt files.
        """
        try:
            # Check if it's a valid Git repository
            repo = Repo(base_path)

            # Verify basic repository structure
            basic_valid = (
                repo.head.is_valid()
                and not repo.bare
                and (base_path / ".git").is_dir()
                and (base_path / ".locks").is_dir()
            )

            if not basic_valid:
                return False

            prompt_files = list(base_path.rglob("*.md"))
            seen_names: Dict[str, Path] = {}

            for prompt_file in prompt_files:
                # Skip files in .git directory
                if ".git" in prompt_file.parts:
                    continue

                # Case-insensitive key
                key = Prompt._normalize_name(prompt_file.stem)

                if key in seen_names:
                    logger.error(
                        f"Duplicate prompt file detected (case-insensitive):\n"
                        f"  First occurrence: {seen_names[key]}\n"
                        f"  Second occurrence: {prompt_file}"
                    )
                    return False

                seen_names[key] = prompt_file

            return True

        except Exception as e:  # InvalidGitRepositoryError is an Exception subclass
            logger.error(f"Repository verification failed: {e}")
            return False
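The path-traversal guard in `_normalize_path` can be sketched standalone: relative paths are joined to the repository base, then `Path.relative_to` rejects anything that resolves outside it. The function below is an illustrative stand-in, not the package API.

```python
import tempfile
from pathlib import Path

def normalize_path(base_path: Path, path) -> Path:
    path = Path(path)
    # Relative paths are always interpreted relative to the repository base.
    if not path.is_absolute():
        path = base_path / path
    resolved = path.resolve()
    try:
        # Raises ValueError if resolved is not under base_path.
        resolved.relative_to(base_path)
    except ValueError as e:
        raise ValueError(f"Path {path} resolves outside repository: {base_path}") from e
    return resolved

base = Path(tempfile.mkdtemp()).resolve()
try:
    normalize_path(base, "../escape.md")
except ValueError:
    pass  # traversal outside the repository is rejected
```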
access_manager = ConcurrentAccessManager(self.base_path / '.locks') instance-attribute
base_path = Path(base_path).resolve() instance-attribute
repo = GitBackedRepository(self.base_path) instance-attribute
__init__(base_path)

Initialize prompt management system.

Parameters:

Name Type Description Default
base_path Path

Base directory for prompt storage

required
Source code in src/tnh_scholar/ai_text_processing/prompts.py
def __init__(self, base_path: Path):
    """
    Initialize prompt management system.

    Args:
        base_path: Base directory for prompt storage
    """
    self.base_path = Path(base_path).resolve()
    self.base_path.mkdir(parents=True, exist_ok=True)

    # Initialize subsystems
    self.repo = GitBackedRepository(self.base_path)
    self.access_manager = ConcurrentAccessManager(self.base_path / ".locks")

    logger.info(f"Initialized prompt management system at {base_path}")
get_path(prompt_name)

Recursively search for a prompt file with the given name (case-insensitive) in base_path and all subdirectories.

Parameters:

Name Type Description Default
prompt_name str

prompt name (without extension) to search for

required

Returns:

Type Description
Optional[Path]

Optional[Path]: Full path to the found prompt file, or None if not found

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def get_path(self, prompt_name: str) -> Optional[Path]:
    """
    Recursively search for a prompt file with the given name (case-insensitive)
    in base_path and all subdirectories.

    Args:
        prompt_name: prompt name (without extension) to search for

    Returns:
        Optional[Path]: Full path to the found prompt file, or None if not found
    """
    target = Prompt._normalize_name(prompt_name)
    with suppress(StopIteration):
        for path in self.base_path.rglob("*.md"):
            if path.is_file() and path.stem.lower() == target:
                logger.debug(
                    f"Found prompt file for name {prompt_name} at: {path}"
                )
                return self._normalize_path(path)
    logger.debug(f"No prompt file found with name: {prompt_name}")
    return None
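The case-insensitive lookup above boils down to comparing lowercased file stems while walking `*.md` files recursively. A minimal sketch, with an illustrative helper name rather than the package API:

```python
import tempfile
from pathlib import Path

def find_prompt(base_path: Path, prompt_name: str):
    # Normalize the target the same way the stored names are normalized.
    target = prompt_name.strip().lower()
    for path in base_path.rglob("*.md"):
        if path.is_file() and path.stem.lower() == target:
            return path
    return None

base = Path(tempfile.mkdtemp())
(base / "nested").mkdir()
(base / "nested" / "Format_Dharma_Talk.md").write_text("...")

found = find_prompt(base, "format_dharma_talk")
```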
load(prompt_name)

Load the .md prompt file by name, extract placeholders, and return a fully constructed Prompt object.

Parameters:

Name Type Description Default
prompt_name str

Name of the prompt (without .md extension).

required

Returns:

Type Description
Prompt

A new Prompt object whose 'instructions' is the file's text and whose 'template_fields' are inferred from placeholders in those instructions.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
def load(self, prompt_name: str) -> Prompt:
    """
    Load the .md prompt file by name, extract placeholders, and
    return a fully constructed Prompt object.

    Args:
        prompt_name: Name of the prompt (without .md extension).

    Returns:
        A new Prompt object whose 'instructions' is the file's text
        and whose 'template_fields' are inferred from placeholders in
        those instructions.
    """
    prompt_name = Prompt._normalize_name(prompt_name)
    # Locate the .md file; raise if missing
    path = self.get_path(prompt_name)
    if not path:
        raise FileNotFoundError(f"No prompt file named {prompt_name}.md found in prompt catalog:\n"
                                f"{self.base_path}"
                                )

    # Acquire lock before reading
    with self.access_manager.file_lock(path):
        instructions = read_str_from_file(path)

    instructions = MarkdownStr(instructions)

    # Create the prompt from the raw .md text (name is already lowercase)
    prompt = Prompt(name=prompt_name, instructions=instructions, path=path)

    # Check for local uncommitted changes, updating file:
    self.repo.update_file(path)

    return prompt
save(prompt, subdir=None)
Source code in src/tnh_scholar/ai_text_processing/prompts.py
def save(self, prompt: Prompt, subdir: Optional[Path] = None) -> Path:
    prompt_name = Prompt._normalize_name(prompt.name)
    instructions = prompt.instructions

    if subdir is None:
        path = self.base_path / f"{prompt_name}.md"
    else:
        path = self.base_path / subdir / f"{prompt_name}.md"

    path = self._normalize_path(path)

    # Check for existing prompt by case-insensitive match
    existing_path = self.get_path(prompt_name)

    try:
        # Lock on the destination path name (lowercase) to avoid races
        with self.access_manager.file_lock(path):
            # If an existing file is present but at a different case/path, rename it
            if existing_path is not None and existing_path != path:
                path.parent.mkdir(parents=True, exist_ok=True)
                logger.info(
                    f"Renaming existing prompt file from {existing_path} to {path} "
                    "to enforce lowercase naming."
                )
                existing_path.rename(path)

            write_str_to_file(path, instructions, overwrite=True)
            self.repo.update_file(path)
            logger.info(f"Prompt saved at {path}")
            return path.relative_to(self.base_path)

    except Exception as e:
        logger.error(f"Failed to save prompt {prompt_name}: {e}")
        raise
show_history(prompt_name)
Source code in src/tnh_scholar/ai_text_processing/prompts.py
def show_history(self, prompt_name: str) -> None:
    if path := self.get_path(prompt_name):
        self.repo.display_history(path)
    else:
        logger.error(f"Path to {prompt_name} not found.")
        return
verify_repository(base_path) classmethod

Verify repository integrity and uniqueness of prompt names.

Performs the following checks:

1. Validates Git repository structure.
2. Ensures no duplicate prompt names exist.

Parameters:

Name Type Description Default
base_path Path

Repository path to verify.

required

Returns:

Name Type Description
bool bool

True if the repository is valid and contains no duplicate prompt files.

Source code in src/tnh_scholar/ai_text_processing/prompts.py
@classmethod
def verify_repository(cls, base_path: Path) -> bool:
    """
    Verify repository integrity and uniqueness of prompt names.

    Performs the following checks:
    1. Validates Git repository structure.
    2. Ensures no duplicate prompt names exist.

    Args:
        base_path: Repository path to verify.

    Returns:
        bool: True if the repository is valid 
        and contains no duplicate prompt files.
    """
    try:
        # Check if it's a valid Git repository
        repo = Repo(base_path)

        # Verify basic repository structure
        basic_valid = (
            repo.head.is_valid()
            and not repo.bare
            and (base_path / ".git").is_dir()
            and (base_path / ".locks").is_dir()
        )

        if not basic_valid:
            return False

        prompt_files = list(base_path.rglob("*.md"))
        seen_names: Dict[str, Path] = {}

        for prompt_file in prompt_files:
            # Skip files in .git directory
            if ".git" in prompt_file.parts:
                continue

            # Case-insensitive key
            key = Prompt._normalize_name(prompt_file.stem)

            if key in seen_names:
                logger.error(
                    f"Duplicate prompt file detected (case-insensitive):\n"
                    f"  First occurrence: {seen_names[key]}\n"
                    f"  Second occurrence: {prompt_file}"
                )
                return False

            seen_names[key] = prompt_file

        return True

    except Exception as e:  # InvalidGitRepositoryError is an Exception subclass
        logger.error(f"Repository verification failed: {e}")
        return False
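The duplicate check in `verify_repository` reduces to keying a dict by the normalized (lowercased) name and flagging the first collision. A standalone sketch with an illustrative helper name:

```python
def find_duplicate(names):
    seen = {}
    for name in names:
        key = name.strip().lower()  # case-insensitive key, as in verify_repository
        if key in seen:
            return (seen[key], name)  # (first occurrence, colliding occurrence)
        seen[key] = name
    return None
```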

response_format

TEXT_SECTIONS_DESCRIPTION = 'Ordered list of logical sections for the text. The sequence of line ranges for the sections must cover every line from start to finish without any overlaps or gaps.' module-attribute
LogicalSection

Bases: BaseModel

A logically coherent section of text.

Source code in src/tnh_scholar/ai_text_processing/response_format.py
class LogicalSection(BaseModel):
    """
    A logically coherent section of text.
    """

    title: str = Field(
        ...,
        description="Meaningful title for the section in the original language of the section.",
    )
    start_line: int = Field(
        ..., description="Starting line number of the section (inclusive)."
    )
    end_line: int = Field(
        ..., description="Ending line number of the section (inclusive)."
    )
end_line = Field(..., description='Ending line number of the section (inclusive).') class-attribute instance-attribute
start_line = Field(..., description='Starting line number of the section (inclusive).') class-attribute instance-attribute
title = Field(..., description='Meaningful title for the section in the original language of the section.') class-attribute instance-attribute
TextObject

Bases: BaseModel

Represents a text in any language broken into coherent logical sections.

Source code in src/tnh_scholar/ai_text_processing/response_format.py
class TextObject(BaseModel):
    """
    Represents a text in any language broken into coherent logical sections.
    """

    language: str = Field(..., description="ISO 639-1 language code of the text.")
    sections: List[LogicalSection] = Field(..., description=TEXT_SECTIONS_DESCRIPTION)
language = Field(..., description='ISO 639-1 language code of the text.') class-attribute instance-attribute
sections = Field(..., description=TEXT_SECTIONS_DESCRIPTION) class-attribute instance-attribute
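The coverage rule stated in TEXT_SECTIONS_DESCRIPTION — consecutive `(start_line, end_line)` ranges must tile the text with no gaps or overlaps — can be checked mechanically. This checker is illustrative and not part of the package API:

```python
def sections_cover_text(ranges, total_lines):
    # Each section must start exactly where the previous one ended + 1.
    expected_start = 1
    for start, end in ranges:
        if start != expected_start or end < start:
            return False
        expected_start = end + 1
    # The final section must end on the last line of the text.
    return expected_start == total_lines + 1
```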

section_processor

text_object

StorageFormatType = Union[StorageFormat, Literal['text', 'json']] module-attribute
logger = get_child_logger(__name__) module-attribute
AIResponse

Bases: BaseModel

Class for dividing large texts into AI-processable segments while maintaining broader document context.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
class AIResponse(BaseModel):
    """Class for dividing large texts into AI-processable segments while
    maintaining broader document context."""
    document_summary: str = Field(
        ...,
        description="Concise, comprehensive overview of the text's content and purpose"
    )
    document_metadata: str = Field(
        ...,
        description="Available Dublin Core standard metadata in human-readable YAML format" # noqa: E501
    )
    key_concepts: str = Field(
        ...,
        description="Important terms, ideas, or references that appear throughout the text"  # noqa: E501
    )
    narrative_context: str = Field(
        ...,
        description="Concise overview of how the text develops or progresses as a whole"
    )
    language: str = Field(..., description="ISO 639-1 language code")
    sections: List[LogicalSection]
document_metadata = Field(..., description='Available Dublin Core standard metadata in human-readable YAML format') class-attribute instance-attribute
document_summary = Field(..., description="Concise, comprehensive overview of the text's content and purpose") class-attribute instance-attribute
key_concepts = Field(..., description='Important terms, ideas, or references that appear throughout the text') class-attribute instance-attribute
language = Field(..., description='ISO 639-1 language code') class-attribute instance-attribute
narrative_context = Field(..., description='Concise overview of how the text develops or progresses as a whole') class-attribute instance-attribute
sections instance-attribute
LoadConfig dataclass

Configuration for loading a TextObject.

Attributes:

Name Type Description
format StorageFormat

Storage format of the input file

source_str Optional[str]

Optional source content as string

source_file Optional[Path]

Optional path to source content file

Note

For JSON format, exactly one of source_str or source_file must be provided. Both fields are ignored for TEXT format.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
@dataclass(frozen=True)
class LoadConfig:
    """Configuration for loading a TextObject.

    Attributes:
        format: Storage format of the input file
        source_str: Optional source content as string
        source_file: Optional path to source content file

    Note:
        For JSON format, exactly one of source_str or source_file must be provided.
        Both fields are ignored for TEXT format.
    """
    format: StorageFormat = StorageFormat.TEXT
    source_str: Optional[str] = None
    source_file: Optional[Path] = None

    def __post_init__(self):
        """Validate configuration."""
        valid_source = (
            (self.source_str is None) ^ (self.source_file is None)
        )
        if self.format == StorageFormat.JSON and not valid_source:
            raise ValueError(
                "Either source_str or source_file (not both) "
                "must be set for JSON format."
            )

    def get_source_text(self) -> Optional[str]:
        """Get source content as text if provided."""
        if self.source_file is not None:
            return read_str_from_file(self.source_file)
        return self.source_str
format = StorageFormat.TEXT class-attribute instance-attribute
source_file = None class-attribute instance-attribute
source_str = None class-attribute instance-attribute
__init__(format=StorageFormat.TEXT, source_str=None, source_file=None)
__post_init__()

Validate configuration.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def __post_init__(self):
    """Validate configuration."""
    valid_source = (
        (self.source_str is None) ^ (self.source_file is None)
    )
    if self.format == StorageFormat.JSON and not valid_source:
        raise ValueError(
            "Either source_str or source_file (not both) "
            "must be set for JSON format."
        )
get_source_text()

Get source content as text if provided.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def get_source_text(self) -> Optional[str]:
    """Get source content as text if provided."""
    if self.source_file is not None:
        return read_str_from_file(self.source_file)
    return self.source_str
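The exactly-one-source rule enforced in `__post_init__` is an XOR on the two None checks: it is true only when exactly one of the sources is set. A minimal sketch, with an illustrative function name:

```python
def valid_json_source(source_str, source_file):
    # True iff exactly one of the two sources is provided.
    return (source_str is None) ^ (source_file is None)
```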
LogicalSection

Bases: BaseModel

Represents a contextually meaningful segment of a larger text.

Sections should preserve natural breaks in content (explicit section markers, topic shifts, argument development, narrative progression) while staying within specified size limits in order to create chunks suitable for AI processing.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
class LogicalSection(BaseModel):
    """
    Represents a contextually meaningful segment of a larger text.

    Sections should preserve natural breaks in content 
    (explicit section markers, topic shifts, argument development, narrative progression) 
    while staying within specified size limits in order to create chunks suitable for AI processing.
    """  # noqa: E501
    start_line: int = Field(
        ..., 
        description="Starting line number that begins this logical segment"
    )
    title: str = Field(
        ...,
        description="Descriptive title of section's key content"
    )
start_line = Field(..., description='Starting line number that begins this logical segment') class-attribute instance-attribute
title = Field(..., description="Descriptive title of section's key content") class-attribute instance-attribute
SectionEntry

Bases: NamedTuple

Represents a section with its content during iteration.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
72
73
74
75
76
77
class SectionEntry(NamedTuple):
    """Represents a section with its content during iteration."""
    number: int         # Logical Section number (1 based index)
    title: str          # Section title 
    content: str        # Section content
    range: SectionRange # Section range
content instance-attribute
number instance-attribute
range instance-attribute
title instance-attribute
SectionObject dataclass

Represents a section of text with metadata.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
@dataclass
class SectionObject:
    """Represents a section of text with metadata."""
    title: str
    section_range: SectionRange
    metadata: Optional[Metadata] 

    @classmethod
    def from_logical_section(
        cls, 
        logical_section: LogicalSection, 
        end_line: int, 
        metadata: Optional[Metadata] = None
        ) -> "SectionObject":
        """Create a SectionObject from a LogicalSection model."""
        return cls(
            title=logical_section.title,
            section_range = SectionRange(logical_section.start_line, end_line),
            metadata = metadata 
        )
metadata instance-attribute
section_range instance-attribute
title instance-attribute
__init__(title, section_range, metadata)
from_logical_section(logical_section, end_line, metadata=None) classmethod

Create a SectionObject from a LogicalSection model.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
@classmethod
def from_logical_section(
    cls, 
    logical_section: LogicalSection, 
    end_line: int, 
    metadata: Optional[Metadata] = None
    ) -> "SectionObject":
    """Create a SectionObject from a LogicalSection model."""
    return cls(
        title=logical_section.title,
        section_range = SectionRange(logical_section.start_line, end_line),
        metadata = metadata 
    )
SectionRange

Bases: NamedTuple

Represents the line range of a section.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
class SectionRange(NamedTuple):
    """Represents the line range of a section."""
    start: int  # Start line (inclusive)
    end: int    # End line (exclusive)
end instance-attribute
start instance-attribute
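The inclusive-start / exclusive-end convention means adjacent sections tile the text with no gap or overlap: one section's `end` is the next section's `start`. A small self-contained sketch (the `segment` helper is hypothetical, standing in for `NumberedText.get_segment`):

```python
from typing import NamedTuple


class SectionRange(NamedTuple):
    """Line range of a section: start inclusive, end exclusive."""
    start: int
    end: int


lines = ["intro", "point one", "point two", "closing"]  # lines 1..4

first = SectionRange(start=1, end=3)   # covers lines 1-2
second = SectionRange(start=3, end=5)  # covers lines 3-4


def segment(lines, rng):
    # Convert a 1-based, end-exclusive range to a Python slice.
    return lines[rng.start - 1 : rng.end - 1]


print(segment(lines, first))   # ['intro', 'point one']
print(segment(lines, second))  # ['point two', 'closing']
```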
StorageFormat

Bases: Enum

Source code in src/tnh_scholar/ai_text_processing/text_object.py
class StorageFormat(Enum):
    TEXT = "text"
    JSON = "json"
JSON = 'json' class-attribute instance-attribute
TEXT = 'text' class-attribute instance-attribute
TextObject

Manages text content with section organization and metadata tracking.

TextObject serves as the core container for text processing, providing:

- Line-numbered text content management
- Language identification
- Section organization and access
- Metadata tracking, including incorporated processing stages

The class defines section boundaries through line numbering: sections are specified by start lines without explicit end lines, and each section implicitly ends where the next section begins. SectionObject instances represent the resulting sections.

Attributes:

    num_text (NumberedText): Line-numbered text content manager
    language (str): ISO 639-1 language code for the text content
    _sections (List[SectionObject]): Internal list of text sections with boundaries
    _metadata (Metadata): Processing and content metadata container

Example:

    >>> content = NumberedText("Line 1\nLine 2\nLine 3")
    >>> obj = TextObject(content, language="en")
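The implicit end-line scheme described above can be sketched with plain tuples, mirroring the logic of the private `_build_section_objects` helper. The `Section` type and `build_sections` function here are hypothetical stand-ins, not package API:

```python
from typing import List, NamedTuple, Tuple


class Section(NamedTuple):
    title: str
    start: int  # inclusive
    end: int    # exclusive


def build_sections(
    starts_and_titles: List[Tuple[int, str]], last_line: int
) -> List[Section]:
    """Each section ends where the next begins; the final one ends at last_line + 1."""
    sections = []
    for i, (start, title) in enumerate(starts_and_titles):
        end = (starts_and_titles[i + 1][0]
               if i < len(starts_and_titles) - 1
               else last_line + 1)
        sections.append(Section(title, start, end))
    return sections


# Three sections declared only by start line, over a 10-line text:
sections = build_sections(
    [(1, "Opening"), (4, "Body"), (8, "Closing")], last_line=10
)
# → [Section('Opening', 1, 4), Section('Body', 4, 8), Section('Closing', 8, 11)]
```

Because each end is derived from the next start, the sections always partition the text exactly, which is what makes start-line-only section definitions safe.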

Source code in src/tnh_scholar/ai_text_processing/text_object.py
class TextObject:
    """
    Manages text content with section organization and metadata tracking.

    TextObject serves as the core container for text processing, providing:
    - Line-numbered text content management
    - Language identification
    - Section organization and access
    - Metadata tracking including incorporated processing stages

    The class defines section boundaries through line numbering:
    sections are specified by start lines without explicit end lines,
    and each section implicitly ends where the next section begins.
    SectionObject instances represent the resulting sections.

    Attributes:
        num_text: Line-numbered text content manager
        language: ISO 639-1 language code for the text content
        _sections: Internal list of text sections with boundaries
        _metadata: Processing and content metadata container

    Example:
        >>> content = NumberedText("Line 1\\nLine 2\\nLine 3")
        >>> obj = TextObject(content, language="en")
    """
    num_text: NumberedText 
    language: str 
    _sections: List[SectionObject]
    _metadata: Metadata

    def __init__(self, 
        num_text: NumberedText, 
        language: Optional[str] = None, 
        sections: Optional[List[SectionObject]] = None,
        metadata: Optional[Metadata] = None):
        """
        Initialize a TextObject with content and optional organizing components.

        Args:
            num_text: Text content with line numbering
            language: ISO 639-1 language code. If None, auto-detected from content
            sections: Initial sections defining text organization. If None, 
                      text is considered un-sectioned.
            metadata: Initial metadata. If None, creates empty metadata container

        Note:
            Until sections are established, section-based methods will raise a value
            error if called.
        """
        self.num_text = num_text
        self.language = language or get_language_code_from_text(num_text.content)
        self._sections = sections or []
        self._metadata = metadata or Metadata()

        if sections:
            self.validate_sections()


    def __iter__(self) -> Iterator[SectionEntry]:
        """Iterate through sections, yielding full section information."""
        if not self._sections:
            raise ValueError("No Sections available.")

        for i, section in enumerate(self._sections):
            content = self.num_text.get_segment(
                section.section_range.start, 
                section.section_range.end
            )
            yield SectionEntry(
                number=i+1,
                title=section.title,
                range=section.section_range,
                content=content
            )

    def __str__(self) -> str:
        return Frontmatter.embed(self.metadata, self.content)

    @staticmethod
    def _build_section_objects(
        logical_sections: List[LogicalSection], 
        last_line: int,
        metadata: Optional[Metadata] = None
    ) -> List[SectionObject]:
        """Convert LogicalSections to SectionObjects with proper ranges."""
        section_objects = []

        for i, section in enumerate(logical_sections):
            # For each section, end is either next section's start or last line + 1
            end_line = (logical_sections[i + 1].start_line 
                    if i < len(logical_sections) - 1 
                    else last_line + 1)

            section_objects.append(
                SectionObject.from_logical_section(section, end_line, metadata)
            )

        return section_objects

    @classmethod
    def from_str(
        cls,
        text: str,
        language: Optional[str] = None,
        sections: Optional[List[SectionObject]] = None,
        metadata: Optional[Metadata] = None
    ) -> 'TextObject':
        """
        Create a TextObject from a string, extracting any frontmatter.

        Args:
            text: Input text string, potentially containing frontmatter
            language: ISO language code
            sections: List of section objects
            metadata: Optional base metadata to merge with frontmatter

        Returns:
            TextObject instance with combined metadata
        """
        # Extract any frontmatter and merge with provided metadata
        frontmatter_metadata, content = Frontmatter.extract(text)

        # Create NumberedText from content without frontmatter
        numbered_text = NumberedText(content)

        obj = cls(
            num_text=numbered_text,
            language=language,
            sections=sections,
            metadata=frontmatter_metadata
        )
        if metadata:
            obj.merge_metadata(metadata)

        return obj


    @classmethod
    def from_response(
        cls, 
        response: AIResponse,
        existing_metadata: Metadata,
        num_text: 'NumberedText'
    ) -> 'TextObject':
        """Create TextObject from AI response format."""
        # Create metadata from response
        ai_metadata = response.document_metadata
        new_metadata = Metadata({
            "ai_summary": response.document_summary,
            "ai_concepts": response.key_concepts,
            "ai_context": response.narrative_context
        })

        # Convert LogicalSections to SectionObjects
        sections = cls._build_section_objects(
            response.sections, 
            num_text.size,
        )

        text = cls(
            num_text=num_text,
            language=response.language,
            sections=sections,
            metadata=existing_metadata
        )
        text.merge_metadata(new_metadata)
        text.merge_metadata(Metadata.from_yaml(ai_metadata))
        return text

    def merge_metadata(self, new_metadata: Metadata, override=False) -> None:
        """
        Merge new metadata with existing metadata.

        For now, performs simple dict-like union (|=) but can be extended 
        to handle more complex merging logic in the future (e.g., merging 
        nested structures, handling conflicts, merging arrays).

        Args:
            new_metadata: Metadata to merge with existing metadata
            override: If True, new_metadata values override existing values.
                If False, existing values are preserved.
        """
        # Currently using simple dict union
        # Future implementations might handle:
        # - Deep merging of nested structures
        # - Special handling of specific fields
        # - Array/list merging strategies
        # - Conflict resolution
        # - Metadata versioning
        if not new_metadata:
            return

        if override:
            self._metadata |= new_metadata  # new overrides existing
        else:
            self._metadata = new_metadata | self._metadata # existing values preserved

        logger.debug("Merging new metadata into TextObject")

    def update_metadata(self, **kwargs) -> None:
        """Update metadata with new key-value pairs."""
        new_metadata = Metadata(kwargs)
        self.merge_metadata(new_metadata)

    def validate_sections(self) -> None:
        """Basic validation of section integrity."""
        if not self._sections:
            raise ValueError("No sections set.")

        # Check section ordering and bounds
        for i, section in enumerate(self._sections):
            if section.section_range.start < 1:
                logger.warning(f"Section {i}: start line must be >= 1")
            if section.section_range.start > self.num_text.size:
                logger.warning(f"Section {i}: start line exceeds text length")
            if i > 0 and \
                section.section_range.start <= self._sections[i-1].section_range.start:
                logger.warning(f"Section {i}: non-sequential start line")

    def get_section_content(self, index: int) -> str:
        """Get content for a section."""
        if not self._sections:
            raise ValueError("No sections available.")
        if index < 0 or index >= len(self._sections):
            raise IndexError("Section index out of range")

        section = self._sections[index]
        return self.num_text.get_segment(
            section.section_range.start, 
            section.section_range.end
        )

    def export_info(self, source_file: Optional[Path] = None) -> TextObjectInfo:
        """Export serializable state."""
        if source_file:
            source_file = source_file.resolve() # use absolute path for info

        return TextObjectInfo(
            source_file=source_file,
            language=self.language,
            sections=self.sections,
            metadata=self.metadata
        )

    @classmethod
    def from_info(
        cls, 
        info: TextObjectInfo, 
        metadata: Metadata, 
        num_text: 'NumberedText'
        ) -> 'TextObject':
        """Create TextObject from info and content."""
        text_obj = cls(
            num_text=num_text, 
            language=info.language, 
            sections=info.sections, 
            metadata=info.metadata
            )

        text_obj.merge_metadata(metadata)
        return text_obj

    @classmethod
    def from_text_file(
        cls,
        file: Path
    ) -> 'TextObject':
        text_str = read_str_from_file(file)
        return cls.from_str(text_str)

    @classmethod
    def from_section_file(
        cls, 
        section_file: Path, 
        source: Optional[str] = None
        ) -> 'TextObject':
        """
        Create TextObject from a section info file, loading content from source_file.
        Metadata is extracted from the source_file or from content.

        Args:
            section_file: Path to JSON file containing TextObjectInfo
            source: Optional source string in case no source file is found.

        Returns:
            TextObject instance

        Raises:
            ValueError: If source_file is missing from section info
            FileNotFoundError: If either section_file or source_file not found
        """
        # Check section file exists
        if not section_file.exists():
            raise FileNotFoundError(f"Section file not found: {section_file}")

        # Load and parse section info
        info = TextObjectInfo.model_validate_json(read_str_from_file(section_file))

        if not source:  # passed content always takes precedence over source_file
            # check if source file exists
            if not info.source_file:
                raise ValueError(f"No content available: no source_file specified "
                                 f"in section info: {section_file}")

            source_path = Path(info.source_file)
            if not source_path.exists():
                raise FileNotFoundError(
                    f"No content available: Source file not found: {source_path}"
                    )

            # Load source from path
            source = read_str_from_file(source_path)

        metadata, content = Frontmatter.extract(source)

        # Create TextObject
        return cls.from_info(info=info, 
                             metadata=metadata, 
                             num_text=NumberedText(content)
                             )

    def save(
        self,
        path: Path,
        output_format: StorageFormatType = StorageFormat.TEXT,
        source_file: Optional[Path] = None,
        pretty: bool = True
        ) -> None:
        """
        Save TextObject to file in specified format.

        Args:
            path: Output file path
            output_format: "text" for full content+metadata or "json" for serialized state
            source_file: Optional source file to record in metadata
            pretty: For JSON output, whether to pretty print
        """
        if isinstance(output_format, str):
            output_format = StorageFormat(output_format)

        if output_format == StorageFormat.TEXT:
            # Full text output with metadata as frontmatter
            write_str_to_file(path, str(self))

        elif output_format == StorageFormat.JSON:
            # Export serializable state
            info = self.export_info(source_file)
            json_str = info.model_dump_json(indent=2 if pretty else None)
            write_str_to_file(path, json_str)

    @classmethod
    def load(
        cls,
        path: Path,
        config: Optional[LoadConfig] = None
    ) -> 'TextObject':
        """
        Load TextObject from file with optional configuration.

        Args:
            path: Input file path
            config: Optional loading configuration. If not provided,
                loads directly from text file.

        Returns:
            TextObject instance

        Usage:
            # Load from text file with frontmatter
            obj = TextObject.load(Path("content.txt"))

            # Load state from JSON with source content string
            config = LoadConfig(
                format=StorageFormat.JSON,
                source_str="Text content..."
            )
            obj = TextObject.load(Path("state.json"), config)

            # Load state from JSON with source content file
            config = LoadConfig(
                format=StorageFormat.JSON,
                source_file=Path("content.txt")
            )
            obj = TextObject.load(Path("state.json"), config)
        """
        # Use default config if none provided
        config = config or LoadConfig()

        if config.format == StorageFormat.TEXT:
            return cls.from_text_file(path)

        elif config.format == StorageFormat.JSON:
            return cls.from_section_file(path, source=config.get_source_text())

        else:
            raise ValueError("Unknown load configuration format.")

    def transform(
        self,
        data_str: Optional[str] = None,
        language: Optional[str] = None, 
        metadata: Optional[Metadata] = None,
        process_metadata: Optional[ProcessMetadata] = None,
        sections: Optional[List[SectionObject]] = None
    ) -> Self:
        """Update TextObject content and metadata in place.

        Optionally modifies the object's content, language, and adds process tracking.
        Process history is maintained in metadata.

        Args:
            data_str: New text content
            language: New language code
            metadata: Metadata to merge into existing metadata
            process_metadata: Process tracking information to record
            sections: Replacement section list
        """
        # Update potentially changed elements
        if data_str:
            self.num_text = NumberedText(data_str)
        if language:
            self.language = language
        if metadata:
            self.merge_metadata(metadata)
        if process_metadata:    
            self._metadata.add_process_info(process_metadata)
        if sections:
            self._sections = sections

        return self

    @property
    def metadata(self) -> Metadata:
        """Access to metadata dictionary."""
        return self._metadata  

    @property
    def section_count(self) -> int:
        return len(self._sections) if self._sections else 0

    @property
    def last_line_num(self) -> int:
        return self.num_text.size

    @property
    def sections(self) -> List[SectionObject]:
        """Access to sections list."""
        return self._sections or []

    @property
    def content(self) -> str:
        return self.num_text.content

    @property
    def metadata_str(self) -> str:
        return self.metadata.to_yaml()

    @property
    def numbered_content(self) -> str:
        return self.num_text.numbered_content
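The two precedence modes of `merge_metadata` above rest on dict union ordering (the right operand wins on key conflicts). A plain-dict sketch, assuming `Metadata` behaves like a dict under `|` as the method's comments indicate:

```python
existing = {"title": "Original Title", "language": "en"}
incoming = {"title": "AI Title", "ai_summary": "brief summary"}

# override=False (default): existing values are preserved on conflict,
# but new keys are still added.
preserved = incoming | existing          # right operand (existing) wins
assert preserved["title"] == "Original Title"
assert preserved["ai_summary"] == "brief summary"

# override=True: incoming values replace existing ones.
overridden = existing | incoming         # right operand (incoming) wins
assert overridden["title"] == "AI Title"
```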
content property
language = language or get_language_code_from_text(num_text.content) instance-attribute
last_line_num property
metadata property

Access to metadata dictionary.

metadata_str property
num_text = num_text instance-attribute
numbered_content property
section_count property
sections property

Access to sections list.

__init__(num_text, language=None, sections=None, metadata=None)

Initialize a TextObject with content and optional organizing components.

Parameters:

    num_text (NumberedText): Text content with line numbering. Required.
    language (Optional[str]): ISO 639-1 language code. If None, auto-detected from content. Default: None.
    sections (Optional[List[SectionObject]]): Initial sections defining text organization. If None, text is considered un-sectioned. Default: None.
    metadata (Optional[Metadata]): Initial metadata. If None, creates empty metadata container. Default: None.

Note

    Until sections are established, section-based methods will raise a value error if called.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def __init__(self, 
    num_text: NumberedText, 
    language: Optional[str] = None, 
    sections: Optional[List[SectionObject]] = None,
    metadata: Optional[Metadata] = None):
    """
    Initialize a TextObject with content and optional organizing components.

    Args:
        num_text: Text content with line numbering
        language: ISO 639-1 language code. If None, auto-detected from content
        sections: Initial sections defining text organization. If None, 
                  text is considered un-sectioned.
        metadata: Initial metadata. If None, creates empty metadata container

    Note:
        Until sections are established, section-based methods will raise a value
        error if called.
    """
    self.num_text = num_text
    self.language = language or get_language_code_from_text(num_text.content)
    self._sections = sections or []
    self._metadata = metadata or Metadata()

    if sections:
        self.validate_sections()
__iter__()

Iterate through sections, yielding full section information.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def __iter__(self) -> Iterator[SectionEntry]:
    """Iterate through sections, yielding full section information."""
    if not self._sections:
        raise ValueError("No Sections available.")

    for i, section in enumerate(self._sections):
        content = self.num_text.get_segment(
            section.section_range.start, 
            section.section_range.end
        )
        yield SectionEntry(
            number=i+1,
            title=section.title,
            range=section.section_range,
            content=content
        )
__str__()
Source code in src/tnh_scholar/ai_text_processing/text_object.py
def __str__(self) -> str:
    return Frontmatter.embed(self.metadata, self.content)
export_info(source_file=None)

Export serializable state.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def export_info(self, source_file: Optional[Path] = None) -> TextObjectInfo:
    """Export serializable state."""
    if source_file:
        source_file = source_file.resolve() # use absolute path for info

    return TextObjectInfo(
        source_file=source_file,
        language=self.language,
        sections=self.sections,
        metadata=self.metadata
    )
from_info(info, metadata, num_text) classmethod

Create TextObject from info and content.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
@classmethod
def from_info(
    cls, 
    info: TextObjectInfo, 
    metadata: Metadata, 
    num_text: 'NumberedText'
    ) -> 'TextObject':
    """Create TextObject from info and content."""
    text_obj = cls(
        num_text=num_text, 
        language=info.language, 
        sections=info.sections, 
        metadata=info.metadata
        )

    text_obj.merge_metadata(metadata)
    return text_obj
from_response(response, existing_metadata, num_text) classmethod

Create TextObject from AI response format.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
@classmethod
def from_response(
    cls, 
    response: AIResponse,
    existing_metadata: Metadata,
    num_text: 'NumberedText'
) -> 'TextObject':
    """Create TextObject from AI response format."""
    # Create metadata from response
    ai_metadata = response.document_metadata
    new_metadata = Metadata({
        "ai_summary": response.document_summary,
        "ai_concepts": response.key_concepts,
        "ai_context": response.narrative_context
    })

    # Convert LogicalSections to SectionObjects
    sections = cls._build_section_objects(
        response.sections, 
        num_text.size,
    )

    text = cls(
        num_text=num_text,
        language=response.language,
        sections=sections,
        metadata=existing_metadata
    )
    text.merge_metadata(new_metadata)
    text.merge_metadata(Metadata.from_yaml(ai_metadata))
    return text
from_section_file(section_file, source=None) classmethod

Create TextObject from a section info file, loading content from source_file. Metadata is extracted from the source_file or from content.

Parameters:

    section_file (Path): Path to JSON file containing TextObjectInfo. Required.
    source (Optional[str]): Optional source string in case no source file is found. Default: None.

Returns:

    TextObject: TextObject instance

Raises:

    ValueError: If source_file is missing from section info
    FileNotFoundError: If either section_file or source_file not found

Source code in src/tnh_scholar/ai_text_processing/text_object.py
@classmethod
def from_section_file(
    cls, 
    section_file: Path, 
    source: Optional[str] = None
    ) -> 'TextObject':
    """
    Create TextObject from a section info file, loading content from source_file.
    Metadata is extracted from the source_file or from content.

    Args:
        section_file: Path to JSON file containing TextObjectInfo
        source: Optional source string in case no source file is found.

    Returns:
        TextObject instance

    Raises:
        ValueError: If source_file is missing from section info
        FileNotFoundError: If either section_file or source_file not found
    """
    # Check section file exists
    if not section_file.exists():
        raise FileNotFoundError(f"Section file not found: {section_file}")

    # Load and parse section info
    info = TextObjectInfo.model_validate_json(read_str_from_file(section_file))

    if not source:  # passed content always takes precedence over source_file
        # check if source file exists
        if not info.source_file:
            raise ValueError(f"No content available: no source_file specified "
                             f"in section info: {section_file}")

        source_path = Path(info.source_file)
        if not source_path.exists():
            raise FileNotFoundError(
                f"No content available: Source file not found: {source_path}"
                )

        # Load source from path
        source = read_str_from_file(source_path)

    metadata, content = Frontmatter.extract(source)

    # Create TextObject
    return cls.from_info(info=info, 
                         metadata=metadata, 
                         num_text=NumberedText(content)
                         )
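The source-resolution precedence in `from_section_file` above (a passed source string always wins; otherwise fall back to the recorded `source_file`) can be sketched as a standalone function. `resolve_source` is a hypothetical helper for illustration, not package API:

```python
from pathlib import Path
from typing import Optional


def resolve_source(source: Optional[str], source_file: Optional[Path]) -> str:
    """Mirror from_section_file's precedence: an explicitly passed
    source string short-circuits any file lookup."""
    if source:
        return source
    if source_file is None:
        raise ValueError("No content available: no source_file specified.")
    if not source_file.exists():
        raise FileNotFoundError(f"Source file not found: {source_file}")
    return source_file.read_text()


# A passed string wins even if the recorded path does not exist:
print(resolve_source("inline content", Path("missing.txt")))  # inline content
```

This ordering lets callers supply content directly (e.g. from memory or a pipeline stage) without requiring the original source file to still be present on disk.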
from_str(text, language=None, sections=None, metadata=None) classmethod

Create a TextObject from a string, extracting any frontmatter.

Parameters:

    text (str): Input text string, potentially containing frontmatter. Required.
    language (Optional[str]): ISO language code. Default: None.
    sections (Optional[List[SectionObject]]): List of section objects. Default: None.
    metadata (Optional[Metadata]): Optional base metadata to merge with frontmatter. Default: None.

Returns:

    TextObject: TextObject instance with combined metadata

Source code in src/tnh_scholar/ai_text_processing/text_object.py
@classmethod
def from_str(
    cls,
    text: str,
    language: Optional[str] = None,
    sections: Optional[List[SectionObject]] = None,
    metadata: Optional[Metadata] = None
) -> 'TextObject':
    """
    Create a TextObject from a string, extracting any frontmatter.

    Args:
        text: Input text string, potentially containing frontmatter
        language: ISO language code
        sections: List of section objects
        metadata: Optional base metadata to merge with frontmatter

    Returns:
        TextObject instance with combined metadata
    """
    # Extract any frontmatter and merge with provided metadata
    frontmatter_metadata, content = Frontmatter.extract(text)

    # Create NumberedText from content without frontmatter
    numbered_text = NumberedText(content)

    obj = cls(
        num_text=numbered_text,
        language=language,
        sections=sections,
        metadata=frontmatter_metadata
    )
    if metadata:
        obj.merge_metadata(metadata)

    return obj
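The interplay of frontmatter extraction and metadata merging in `from_str` can be sketched with a simplified stand-in (the real `Frontmatter.extract` and `Metadata` live elsewhere in the package; `extract_frontmatter` below is an illustrative, minimal `key: value` parser, not the package implementation):

```python
from typing import Dict, Tuple


def extract_frontmatter(text: str) -> Tuple[Dict[str, str], str]:
    """Simplified stand-in for Frontmatter.extract: parse a leading
    '---'-delimited block of 'key: value' lines, return (metadata, body)."""
    lines = text.split("\n")
    if lines and lines[0].strip() == "---":
        try:
            end = lines[1:].index("---") + 1
        except ValueError:
            return {}, text
        meta = {}
        for line in lines[1:end]:
            key, _, value = line.partition(":")
            meta[key.strip()] = value.strip()
        return meta, "\n".join(lines[end + 1:])
    return {}, text


doc = "---\ntitle: Old Path White Clouds\nlanguage: en\n---\nChapter one begins here."
meta, body = extract_frontmatter(doc)

# from_str merges caller-supplied base metadata without overriding frontmatter:
base = {"language": "vi", "source": "archive"}
combined = base | meta  # frontmatter values win over the base metadata
```

This mirrors the `override=False` path of `merge_metadata`: frontmatter values, already on the object, are preserved over the merged-in base.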
from_text_file(file) classmethod
Source code in src/tnh_scholar/ai_text_processing/text_object.py
@classmethod
def from_text_file(
    cls,
    file: Path
) -> 'TextObject':
    text_str = read_str_from_file(file)
    return cls.from_str(text_str)
get_section_content(index)
Source code in src/tnh_scholar/ai_text_processing/text_object.py
def get_section_content(self, index: int) -> str:
    """Get content for a section."""
    if not self._sections:
        raise ValueError("No sections available.")
    if index < 0 or index >= len(self._sections):
        raise IndexError("Section index out of range")

    section = self._sections[index]
    return self.num_text.get_segment(
        section.section_range.start,
        section.section_range.end
    )
load(path, config=None) classmethod

Load TextObject from file with optional configuration.

Parameters:

Name Type Description Default
path Path

Input file path

required
config Optional[LoadConfig]

Optional loading configuration. If not provided, loads directly from text file.

None

Returns:

Type Description
TextObject

TextObject instance

Usage
Load from text file with frontmatter

obj = TextObject.load(Path("content.txt"))

Load state from JSON with source content string

config = LoadConfig(
    format=StorageFormat.JSON,
    source_content="Text content..."
)
obj = TextObject.load(Path("state.json"), config)

Load state from JSON with source content file

config = LoadConfig(
    format=StorageFormat.JSON,
    source_content=Path("content.txt")
)
obj = TextObject.load(Path("state.json"), config)

Source code in src/tnh_scholar/ai_text_processing/text_object.py
@classmethod
def load(
    cls,
    path: Path,
    config: Optional[LoadConfig] = None
) -> 'TextObject':
    """
    Load TextObject from file with optional configuration.

    Args:
        path: Input file path
        config: Optional loading configuration. If not provided,
            loads directly from text file.

    Returns:
        TextObject instance

    Usage:
        # Load from text file with frontmatter
        obj = TextObject.load(Path("content.txt"))

        # Load state from JSON with source content string
        config = LoadConfig(
            format=StorageFormat.JSON,
            source_content="Text content..."
        )
        obj = TextObject.load(Path("state.json"), config)

        # Load state from JSON with source content file
        config = LoadConfig(
            format=StorageFormat.JSON,
            source_content=Path("content.txt")
        )
        obj = TextObject.load(Path("state.json"), config)
    """
    # Use default config if none provided
    config = config or LoadConfig()

    if config.format == StorageFormat.TEXT:
        return cls.from_text_file(path)

    elif config.format == StorageFormat.JSON:
        return cls.from_section_file(path, source=config.get_source_text())

    else:
        raise ValueError("Unknown load configuration format.")
merge_metadata(new_metadata, override=False)

Merge new metadata with existing metadata.

For now, performs simple dict-like union (|=) but can be extended to handle more complex merging logic in the future (e.g., merging nested structures, handling conflicts, merging arrays).

Args:
    new_metadata: Metadata to merge with existing metadata
    override: If True, new_metadata values override existing values.
        If False, existing values are preserved.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def merge_metadata(self, new_metadata: Metadata, override=False) -> None:
    """
    Merge new metadata with existing metadata.

    For now, performs simple dict-like union (|=) but can be extended 
    to handle more complex merging logic in the future (e.g., merging 
    nested structures, handling conflicts, merging arrays).

    Args:
        new_metadata: Metadata to merge with existing metadata
        override: If True, new_metadata values override existing values.
            If False, existing values are preserved.
    """
    # Currently using simple dict union
    # Future implementations might handle:
    # - Deep merging of nested structures
    # - Special handling of specific fields
    # - Array/list merging strategies
    # - Conflict resolution
    # - Metadata versioning
    if not new_metadata:
        return

    if override:
        self._metadata |= new_metadata  # new overrides existing
    else:
        self._metadata = new_metadata | self._metadata # existing values preserved

    logger.debug("Merging new metadata into TextObject")
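Since `Metadata` behaves like a dict here, the `override` flag simply maps onto the order of Python's dict union operator. With plain dicts (illustrative values), the two directions behave as:

```python
existing = {"title": "Sample Talk", "language": "en"}
new = {"language": "vi", "source": "frontmatter"}

# override=True: new values win (existing |= new)
merged_override = existing | new
assert merged_override["language"] == "vi"

# override=False: existing values preserved (new | existing)
merged_preserve = new | existing
assert merged_preserve["language"] == "en"
assert merged_preserve["source"] == "frontmatter"  # non-conflicting keys still merge in
```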
save(path, output_format=StorageFormat.TEXT, source_file=None, pretty=True)

Save TextObject to file in specified format.

Parameters:

Name Type Description Default
path Path

Output file path

required
output_format StorageFormatType

"text" for full content+metadata or "json" for serialized state

StorageFormat.TEXT
source_file Optional[Path]

Optional source file to record in metadata

None
pretty bool

For JSON output, whether to pretty print

True
Source code in src/tnh_scholar/ai_text_processing/text_object.py
def save(
    self,
    path: Path,
    output_format: StorageFormatType = StorageFormat.TEXT,
    source_file: Optional[Path] = None,
    pretty: bool = True
    ) -> None:
    """
    Save TextObject to file in specified format.

    Args:
        path: Output file path
        output_format: "text" for full content+metadata or "json" for serialized state
        source_file: Optional source file to record in metadata
        pretty: For JSON output, whether to pretty print
    """
    if isinstance(output_format, str):
        output_format = StorageFormat(output_format)

    if output_format == StorageFormat.TEXT:
        # Full text output with metadata as frontmatter
        write_str_to_file(path, str(self))

    elif output_format == StorageFormat.JSON:
        # Export serializable state
        info = self.export_info(source_file)
        json_str = info.model_dump_json(indent=2 if pretty else None)
        write_str_to_file(path, json_str)
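The string-to-enum coercion and the `pretty` flag can be sketched with a stand-in enum (the `StorageFormat` below is a local illustration of the pattern, not an import from the package):

```python
import json
from enum import Enum


class StorageFormat(str, Enum):  # illustrative stand-in
    TEXT = "text"
    JSON = "json"


def coerce(fmt) -> StorageFormat:
    """Accept either a StorageFormat member or its string value,
    mirroring save()'s isinstance(output_format, str) check."""
    return StorageFormat(fmt) if isinstance(fmt, str) else fmt


# The pretty flag selects indent=2 vs. compact output, as in model_dump_json:
state = {"language": "en", "sections": []}
pretty_out = json.dumps(state, indent=2)
compact_out = json.dumps(state, separators=(",", ":"))
```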
transform(data_str=None, language=None, metadata=None, process_metadata=None, sections=None)

Update TextObject content and metadata in place.

Optionally modifies the object's content, language, and adds process tracking. Process history is maintained in metadata.

Parameters:

Name Type Description Default
data_str Optional[str]

New text content

None
language Optional[str]

New language code

None
metadata Optional[Metadata]

Metadata to merge into existing metadata

None
process_metadata Optional[ProcessMetadata]

Process tracking information to record in process history

None
sections Optional[List[SectionObject]]

New list of section objects

None
Source code in src/tnh_scholar/ai_text_processing/text_object.py
def transform(
    self,
    data_str: Optional[str] = None,
    language: Optional[str] = None, 
    metadata: Optional[Metadata] = None,
    process_metadata: Optional[ProcessMetadata] = None,
    sections: Optional[List[SectionObject]] = None
) -> Self:
    """Update TextObject content and metadata in place.

    Optionally modifies the object's content, language, and adds process tracking.
    Process history is maintained in metadata.

    Args:
        data_str: New text content
        language: New language code
        metadata: Metadata to merge into existing metadata
        process_metadata: Process tracking information to record
        sections: New list of section objects
    """
    # Update potentially changed elements
    if data_str:
        self.num_text = NumberedText(data_str)
    if language:
        self.language = language
    if metadata:
        self.merge_metadata(metadata)
    if process_metadata:    
        self._metadata.add_process_info(process_metadata)
    if sections:
        self._sections = sections

    return self
update_metadata(**kwargs)

Update metadata with new key-value pairs.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def update_metadata(self, **kwargs) -> None:
    """Update metadata with new key-value pairs."""
    new_metadata = Metadata(kwargs)
    self.merge_metadata(new_metadata)
validate_sections()

Basic validation of section integrity.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def validate_sections(self) -> None:
    """Basic validation of section integrity."""
    if not self._sections:
        raise ValueError("No sections set.")

    # Check section ordering and bounds
    for i, section in enumerate(self._sections):
        if section.section_range.start < 1:
            logger.warning(f"Section {i}: start line must be >= 1")
        if section.section_range.start > self.num_text.size:
            logger.warning(f"Section {i}: start line exceeds text length")
        if i > 0 and \
            section.section_range.start <= self._sections[i-1].section_range.start:
            logger.warning(f"Section {i}: non-sequential start line")
TextObjectInfo

Bases: BaseModel

Serializable information about a text and its sections.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
class TextObjectInfo(BaseModel):
    """Serializable information about a text and its sections."""
    source_file: Optional[Path] = None  # Original text file path
    language: str
    sections: List[SectionObject]
    metadata: Metadata

    def model_post_init(self, __context: Any) -> None:
        """Ensure metadata is always a Metadata instance after initialization."""
        if isinstance(self.metadata, dict):
            self.metadata = Metadata(self.metadata)
        elif not isinstance(self.metadata, Metadata):
            raise ValueError(f"Unexpected type for metadata: {type(self.metadata)}")
language instance-attribute
metadata instance-attribute
sections instance-attribute
source_file = None class-attribute instance-attribute
model_post_init(__context)

Ensure metadata is always a Metadata instance after initialization.

Source code in src/tnh_scholar/ai_text_processing/text_object.py
def model_post_init(self, __context: Any) -> None:
    """Ensure metadata is always a Metadata instance after initialization."""
    if isinstance(self.metadata, dict):
        self.metadata = Metadata(self.metadata)
    elif not isinstance(self.metadata, Metadata):
        raise ValueError(f"Unexpected type for metadata: {type(self.metadata)}")

typing

ProcessorResult = Union[str, ResponseFormat] module-attribute
ResponseFormat = TypeVar('ResponseFormat', bound=BaseModel) module-attribute

audio_processing

__all__ = ['DiarizationConfig', 'detect_nonsilent', 'detect_whisper_boundaries', 'split_audio', 'split_audio_at_boundaries'] module-attribute

DiarizationConfig

Bases: BaseSettings

Source code in src/tnh_scholar/audio_processing/diarization/config.py
class DiarizationConfig(BaseSettings):
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive = False,
        env_prefix = "DIARIZATION_",
        extra="ignore",
    )
    speaker: SpeakerConfig = SpeakerConfig()
    chunk: ChunkConfig = ChunkConfig()
    language: LanguageConfig = LanguageConfig()
    mapping: MappingPolicy = MappingPolicy()
chunk = ChunkConfig() class-attribute instance-attribute
language = LanguageConfig() class-attribute instance-attribute
mapping = MappingPolicy() class-attribute instance-attribute
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', case_sensitive=False, env_prefix='DIARIZATION_', extra='ignore') class-attribute instance-attribute
speaker = SpeakerConfig() class-attribute instance-attribute
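`DiarizationConfig` picks up `DIARIZATION_`-prefixed environment variables via pydantic-settings. The prefix convention itself can be sketched with the standard library (a simplified illustration of the naming scheme, not the pydantic-settings implementation, which also handles `.env` files and nested models):

```python
import os


def read_prefixed_settings(prefix: str = "DIARIZATION_") -> dict:
    """Collect environment variables carrying the given prefix,
    lower-casing the remainder of the name (case-insensitive lookup)."""
    return {
        name[len(prefix):].lower(): value
        for name, value in os.environ.items()
        if name.upper().startswith(prefix)
    }


os.environ["DIARIZATION_LANGUAGE"] = "vi"
settings = read_prefixed_settings()
```

With `extra="ignore"`, unknown prefixed variables are silently dropped rather than raising validation errors.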

detect_whisper_boundaries(audio_file, model_size='tiny', language=None)

Detect sentence boundaries using a Whisper model.

Parameters:

Name Type Description Default
audio_file Path

Path to the audio file.

required
model_size str

Whisper model size.

'tiny'
language str

Language to force for transcription (e.g. 'en', 'vi'), or None for auto.

None

Returns:

Type Description
Tuple[List[Boundary], Dict]

Tuple[List[Boundary], Dict]: Sentence boundaries with text, plus the full transcription result.

Example

boundaries, transcription = detect_whisper_boundaries(Path("my_audio.mp3"), model_size="tiny")
for b in boundaries:
    print(b.start, b.end, b.text)

Source code in src/tnh_scholar/audio_processing/audio_legacy.py
def detect_whisper_boundaries(
    audio_file: Path, model_size: str = "tiny", language: Optional[str] = None
) -> Tuple[List[Boundary], Dict]:
    """
    Detect sentence boundaries using a Whisper model.

    Args:
        audio_file (Path): Path to the audio file.
        model_size (str): Whisper model size.
        language (str): Language to force for transcription (e.g. 'en', 'vi'), or None for auto.

    Returns:
        Tuple[List[Boundary], Dict]: Sentence boundaries with text, plus the
            full Whisper transcription result.

    Example:
        >>> boundaries, _ = detect_whisper_boundaries(Path("my_audio.mp3"), model_size="tiny")
        >>> for b in boundaries:
        ...     print(b.start, b.end, b.text)
    """

    os.environ["KMP_WARNINGS"] = "0"  # Turn off OMP warning messages

    # Load model
    logger.info("Loading Whisper model...")
    model = load_whisper_model(model_size)
    logger.info(f"Model '{model_size}' loaded.")

    if language:
        logger.info(f"Language for boundaries set to '{language}'")
    else:
        logger.info("Language not set. Autodetect will be used in Whisper model.")

    # with TimeProgress(expected_time=expected_time, desc="Generating transcription boundaries"):
    boundary_transcription = whisper_model_transcribe(
        model,
        str(audio_file),
        task="transcribe",
        word_timestamps=True,
        language=language,
        verbose=False,
    )

    sentence_boundaries = [
        Boundary(start=segment["start"], end=segment["end"], text=segment["text"])
        for segment in boundary_transcription["segments"]
    ]
    return sentence_boundaries, boundary_transcription
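The final step is a direct comprehension over Whisper's `result["segments"]` list. With a dataclass stand-in for `Boundary` (mirroring the one documented in `audio_legacy`) and a hand-built result dict in place of a real transcription:

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Boundary:  # stand-in mirroring audio_legacy.Boundary
    start: float
    end: float
    text: str = ""


def boundaries_from_transcription(result: Dict) -> List[Boundary]:
    """Build Boundary records from a Whisper-style result dict."""
    return [
        Boundary(start=seg["start"], end=seg["end"], text=seg["text"])
        for seg in result["segments"]
    ]


result = {"segments": [{"start": 0.0, "end": 4.2, "text": "Breathing in"},
                       {"start": 4.2, "end": 8.0, "text": "I calm my body"}]}
bounds = boundaries_from_transcription(result)
```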

split_audio(audio_file, method='whisper', output_dir=None, model_size='tiny', language=None, min_silence_len=MIN_SILENCE_LENGTH, silence_thresh=SILENCE_DBFS_THRESHOLD, max_duration=MAX_DURATION)

High-level function to split an audio file into chunks based on a chosen method.

Parameters:

Name Type Description Default
audio_file Path

The input audio file.

required
method str

Splitting method, "silence" or "whisper".

'whisper'
output_dir Path

Directory to store output.

None
model_size str

Whisper model size if method='whisper'.

'tiny'
language str

Language for whisper transcription if method='whisper'.

None
min_silence_len int

For silence-based detection, min silence length in ms.

MIN_SILENCE_LENGTH
silence_thresh int

Silence threshold in dBFS.

SILENCE_DBFS_THRESHOLD
max_duration int

Max chunk length in seconds.

MAX_DURATION

Returns:

Name Type Description
Path Path

Directory containing the resulting chunks.

Example
Split using silence detection

split_audio(Path("my_audio.mp3"), method="silence")

Split using whisper-based sentence boundaries

split_audio(Path("my_audio.mp3"), method="whisper", model_size="base", language="en")

Source code in src/tnh_scholar/audio_processing/audio_legacy.py
def split_audio(
    audio_file: Path,
    method: str = "whisper",
    output_dir: Optional[Path] = None,
    model_size: str = "tiny",
    language: Optional[str] = None,
    min_silence_len: int = MIN_SILENCE_LENGTH,
    silence_thresh: int = SILENCE_DBFS_THRESHOLD,
    max_duration: int = MAX_DURATION,
) -> Path:
    """
    High-level function to split an audio file into chunks based on a chosen method.

    Args:
        audio_file (Path): The input audio file.
        method (str): Splitting method, "silence" or "whisper".
        output_dir (Path): Directory to store output.
        model_size (str): Whisper model size if method='whisper'.
        language (str): Language for whisper transcription if method='whisper'.
        min_silence_len (int): For silence-based detection, min silence length in ms.
        silence_thresh (int): Silence threshold in dBFS.
        max_duration (int): Max chunk length in seconds.

    Returns:
        Path: Directory containing the resulting chunks.

    Example:
        >>> # Split using silence detection
        >>> split_audio(Path("my_audio.mp3"), method="silence")

        >>> # Split using whisper-based sentence boundaries
        >>> split_audio(Path("my_audio.mp3"), method="whisper", model_size="base", language="en")
    """

    logger.info(f"Splitting audio with max_duration={max_duration} seconds")

    if method == "whisper":
        boundaries, _ = detect_whisper_boundaries(
            audio_file, model_size=model_size, language=language
        )

    elif method == "silence":
        max_duration_ms = (
            max_duration * 1000
        )  # convert duration in seconds to milliseconds
        boundaries = detect_silence_boundaries(
            audio_file,
            min_silence_len=min_silence_len,
            silence_thresh=silence_thresh,
            max_duration=max_duration_ms,
        )
    else:
        raise ValueError(f"Unknown method: {method}. Must be 'silence' or 'whisper'.")

    # Note: split_audio_at_boundaries clears existing files in output_dir (useful for reprocessing)

    return split_audio_at_boundaries(
        audio_file, boundaries, output_dir=output_dir, max_duration=max_duration
    )

split_audio_at_boundaries(audio_file, boundaries, output_dir=None, max_duration=MAX_DURATION)

Split the audio file into chunks based on provided boundaries, ensuring all audio is included and boundaries align with the start of Whisper segments.

Parameters:

Name Type Description Default
audio_file Path

The input audio file.

required
boundaries List[Boundary]

Detected boundaries.

required
output_dir Path

Directory to store the resulting chunks.

None
max_duration int

Maximum chunk length in seconds.

MAX_DURATION

Returns:

Name Type Description
Path Path

Directory containing the chunked audio files.

Example

boundaries = [Boundary(34.02, 37.26, "..."), Boundary(38.0, 41.18, "...")]
out_dir = split_audio_at_boundaries(Path("my_audio.mp3"), boundaries)

Source code in src/tnh_scholar/audio_processing/audio_legacy.py
def split_audio_at_boundaries(
    audio_file: Path,
    boundaries: List[Boundary],
    output_dir: Optional[Path] = None,
    max_duration: int = MAX_DURATION,
) -> Path:
    """
    Split the audio file into chunks based on provided boundaries, ensuring all audio is included
    and boundaries align with the start of Whisper segments.

    Args:
        audio_file (Path): The input audio file.
        boundaries (List[Boundary]): Detected boundaries.
        output_dir (Path): Directory to store the resulting chunks.
        max_duration (int): Maximum chunk length in seconds.

    Returns:
        Path: Directory containing the chunked audio files.

    Example:
        >>> boundaries = [Boundary(34.02, 37.26, "..."), Boundary(38.0, 41.18, "...")]
        >>> out_dir = split_audio_at_boundaries(Path("my_audio.mp3"), boundaries)
    """
    logger.info(f"Splitting audio with max_duration={max_duration} seconds")

    # Load the audio file
    audio = AudioSegment.from_file(audio_file)

    # Create output directory based on filename
    if output_dir is None:
        output_dir = audio_file.parent / f"{audio_file.stem}_chunks"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Clean up the output directory
    for file in output_dir.iterdir():
        if file.is_file():
            logger.info(f"Deleting existing file: {file}")
            file.unlink()

    chunk_start = 0  # Start time for the first chunk in ms
    chunk_count = 1
    current_chunk = AudioSegment.empty()

    for idx, boundary in enumerate(boundaries):
        segment_start_ms = int(boundary.start * 1000)
        if idx + 1 < len(boundaries):
            segment_end_ms = int(
                boundaries[idx + 1].start * 1000
            )  # Next boundary's start
        else:
            segment_end_ms = len(audio)  # End of the audio for the last boundary

        # Adjust for the first segment starting at 0
        if idx == 0 and segment_start_ms > 0:
            segment_start_ms = 0  # Ensure we include the very beginning of the audio

        segment = audio[segment_start_ms:segment_end_ms]

        logger.debug(
            f"Boundary index: {idx}, segment_start: {segment_start_ms / 1000}, segment_end: {segment_end_ms / 1000}, duration: {segment.duration_seconds}"
        )
        logger.debug(f"Current chunk Duration (s): {current_chunk.duration_seconds}")

        if len(current_chunk) + len(segment) <= max_duration * 1000:
            # Add segment to the current chunk
            current_chunk += segment
        else:
            # Export current chunk
            chunk_path = output_dir / f"chunk_{chunk_count}.mp3"
            current_chunk.export(chunk_path, format="mp3")
            logger.info(f"Exported: {chunk_path}")
            chunk_count += 1

            # Start a new chunk with the current segment
            current_chunk = segment

    # Export the final chunk if any audio remains
    if len(current_chunk) > 0:
        chunk_path = output_dir / f"chunk_{chunk_count}.mp3"
        current_chunk.export(chunk_path, format="mp3")
        logger.info(f"Exported: {chunk_path}")

    return output_dir
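The export loop above is a greedy accumulation: segments are appended to the current chunk until adding the next one would exceed `max_duration`, at which point the chunk is flushed and a new one begins. The same logic on plain `(start, end)` ranges in seconds (a sketch of the grouping step only, without pydub):

```python
from typing import List, Tuple


def accumulate_chunks(segments: List[Tuple[float, float]],
                      max_duration: float) -> List[List[Tuple[float, float]]]:
    """Greedily group consecutive segments so each group's total
    duration stays within max_duration (mirrors the export loop)."""
    chunks: List[List[Tuple[float, float]]] = []
    current: List[Tuple[float, float]] = []
    current_len = 0.0
    for start, end in segments:
        seg_len = end - start
        if current_len + seg_len <= max_duration:
            current.append((start, end))
            current_len += seg_len
        else:
            if current:
                chunks.append(current)  # flush the full chunk
            current = [(start, end)]   # start a new chunk with this segment
            current_len = seg_len
    if current:
        chunks.append(current)  # flush the final partial chunk
    return chunks


chunks = accumulate_chunks([(0, 4), (4, 9), (9, 12)], max_duration=10)
```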

audio_legacy

EXPECTED_TIME_FACTOR = 0.45 module-attribute
MAX_DURATION = 10 * 60 module-attribute
MAX_DURATION_MS = 10 * 60 * 1000 module-attribute
MAX_INT16 = 32768.0 module-attribute
MIN_SILENCE_LENGTH = 1000 module-attribute
SEEK_LENGTH = 50 module-attribute
SILENCE_DBFS_THRESHOLD = -30 module-attribute
logger = get_child_logger('audio_processing') module-attribute
Boundary dataclass

A data structure representing a detected audio boundary.

Attributes:

Name Type Description
start float

Start time of the segment in seconds.

end float

End time of the segment in seconds.

text str

Associated text (empty if silence-based).

Example

>>> b = Boundary(start=0.0, end=30.0, text="Hello world")
>>> b.start, b.end, b.text
(0.0, 30.0, 'Hello world')

Source code in src/tnh_scholar/audio_processing/audio_legacy.py
@dataclass
class Boundary:
    """A data structure representing a detected audio boundary.

    Attributes:
        start (float): Start time of the segment in seconds.
        end (float): End time of the segment in seconds.
        text (str): Associated text (empty if silence-based).

    Example:
        >>> b = Boundary(start=0.0, end=30.0, text="Hello world")
        >>> b.start, b.end, b.text
        (0.0, 30.0, 'Hello world')
    """

    start: float
    end: float
    text: str = ""
end instance-attribute
start instance-attribute
text = '' class-attribute instance-attribute
__init__(start, end, text='')
audio_to_numpy(audio_segment)

Convert an AudioSegment object to a NumPy array suitable for Whisper.

Parameters:

Name Type Description Default
audio_segment AudioSegment

The input audio segment to convert.

required

Returns:

Type Description
ndarray

np.ndarray: A mono-channel NumPy array normalized to the range [-1, 1].

Example

audio = AudioSegment.from_file("example.mp3")
audio_numpy = audio_to_numpy(audio)

Source code in src/tnh_scholar/audio_processing/audio_legacy.py
def audio_to_numpy(audio_segment: AudioSegment) -> np.ndarray:
    """
    Convert an AudioSegment object to a NumPy array suitable for Whisper.

    Args:
        audio_segment (AudioSegment): The input audio segment to convert.

    Returns:
        np.ndarray: A mono-channel NumPy array normalized to the range [-1, 1].

    Example:
        >>> audio = AudioSegment.from_file("example.mp3")
        >>> audio_numpy = audio_to_numpy(audio)
    """
    # Convert the audio segment to raw sample data
    raw_data = np.array(audio_segment.get_array_of_samples()).astype(np.float32)

    # Normalize data to the range [-1, 1]
    raw_data /= MAX_INT16

    # Ensure mono-channel (use first channel if stereo)
    if audio_segment.channels > 1:
        raw_data = raw_data.reshape(-1, audio_segment.channels)[:, 0]

    return raw_data
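The normalization and channel handling can be illustrated without pydub or NumPy: int16 samples divided by 32768.0 land in [-1, 1), and interleaved stereo keeps every other sample for channel 0 (a simplified stand-in for the reshape above, on plain Python lists):

```python
MAX_INT16 = 32768.0


def normalize_samples(samples, channels=1):
    """Scale raw int16 samples to [-1, 1); for interleaved multi-channel
    audio, keep only the first channel (every `channels`-th sample)."""
    mono = samples[::channels] if channels > 1 else samples
    return [s / MAX_INT16 for s in mono]


# Interleaved stereo: [L0, R0, L1, R1] -> keep L0, L1
normalized = normalize_samples([0, 16384, -32768, 32767], channels=2)
```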
detect_silence_boundaries(audio_file, min_silence_len=MIN_SILENCE_LENGTH, silence_thresh=SILENCE_DBFS_THRESHOLD, max_duration=MAX_DURATION_MS)

Detect boundaries (start/end times) based on silence detection.

Parameters:

Name Type Description Default
audio_file Path

Path to the audio file.

required
min_silence_len int

Minimum silence length to consider for splitting (ms).

MIN_SILENCE_LENGTH
silence_thresh int

Silence threshold in dBFS.

SILENCE_DBFS_THRESHOLD
max_duration int

Maximum duration of any segment (ms).

MAX_DURATION_MS

Returns:

Type Description
List[Boundary]

List[Boundary]: A list of boundaries with empty text.

Example

boundaries = detect_silence_boundaries(Path("my_audio.mp3"))
for b in boundaries:
    print(b.start, b.end)

Source code in src/tnh_scholar/audio_processing/audio_legacy.py
def detect_silence_boundaries(
    audio_file: Path,
    min_silence_len: int = MIN_SILENCE_LENGTH,
    silence_thresh: int = SILENCE_DBFS_THRESHOLD,
    max_duration: int = MAX_DURATION_MS,
) -> List[Boundary]:
    """
    Detect boundaries (start/end times) based on silence detection.

    Args:
        audio_file (Path): Path to the audio file.
        min_silence_len (int): Minimum silence length to consider for splitting (ms).
        silence_thresh (int): Silence threshold in dBFS.
        max_duration (int): Maximum duration of any segment (ms).

    Returns:
        List[Boundary]: A list of boundaries with empty text.

    Example:
        >>> boundaries = detect_silence_boundaries(Path("my_audio.mp3"))
        >>> for b in boundaries:
        ...     print(b.start, b.end)
    """
    logger.debug(
        f"Detecting silence boundaries with min_silence={min_silence_len}, silence_thresh={silence_thresh}"
    )

    audio = AudioSegment.from_file(audio_file)
    nonsilent_ranges = detect_nonsilent(
        audio,
        min_silence_len=min_silence_len,
        silence_thresh=silence_thresh,
        seek_step=SEEK_LENGTH,
    )

    # Combine ranges to enforce max_duration
    if not nonsilent_ranges:
        # If no nonsilent segments found, return entire file as one boundary
        duration_s = len(audio) / 1000.0
        return [Boundary(start=0.0, end=duration_s, text="")]

    combined_ranges = []
    current_start, current_end = nonsilent_ranges[0]
    for start, end in nonsilent_ranges[1:]:
        if (current_end - current_start) + (end - start) <= max_duration:
            # Extend the current segment
            current_end = end
        else:
            combined_ranges.append((current_start, current_end))
            current_start, current_end = start, end
    combined_ranges.append((current_start, current_end))

    return [
        Boundary(start=start_ms / 1000.0, end=end_ms / 1000.0, text="")
        for start_ms, end_ms in combined_ranges
    ]
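The range-combining pass above merges consecutive nonsilent spans while the combined span stays within `max_duration`. Extracted onto plain `(start_ms, end_ms)` tuples (a sketch of that loop alone, without the pydub silence detection):

```python
from typing import List, Tuple


def combine_ranges(ranges: List[Tuple[int, int]],
                   max_duration_ms: int) -> List[Tuple[int, int]]:
    """Merge consecutive nonsilent (start_ms, end_ms) ranges while the
    combined content length stays within max_duration_ms."""
    if not ranges:
        return []
    combined = []
    current_start, current_end = ranges[0]
    for start, end in ranges[1:]:
        if (current_end - current_start) + (end - start) <= max_duration_ms:
            current_end = end  # absorb the next range (and the gap before it)
        else:
            combined.append((current_start, current_end))
            current_start, current_end = start, end
    combined.append((current_start, current_end))
    return combined


merged = combine_ranges([(0, 3000), (4000, 6000), (7000, 9500)], max_duration_ms=8000)
```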
split_audio(audio_file, method='whisper', output_dir=None, model_size='tiny', language=None, min_silence_len=MIN_SILENCE_LENGTH, silence_thresh=SILENCE_DBFS_THRESHOLD, max_duration=MAX_DURATION)

High-level function to split an audio file into chunks based on a chosen method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| audio_file | Path | The input audio file. | required |
| method | str | Splitting method, "silence" or "whisper". | 'whisper' |
| output_dir | Path | Directory to store output. | None |
| model_size | str | Whisper model size if method='whisper'. | 'tiny' |
| language | str | Language for whisper transcription if method='whisper'. | None |
| min_silence_len | int | For silence-based detection, min silence length in ms. | MIN_SILENCE_LENGTH |
| silence_thresh | int | Silence threshold in dBFS. | SILENCE_DBFS_THRESHOLD |
| max_duration | int | Max chunk length in seconds (converted to ms internally for silence detection). | MAX_DURATION |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| Path | Path | Directory containing the resulting chunks. |

Example:

    # Split using silence detection
    split_audio(Path("my_audio.mp3"), method="silence")

    # Split using whisper-based sentence boundaries
    split_audio(Path("my_audio.mp3"), method="whisper", model_size="base", language="en")

Source code in src/tnh_scholar/audio_processing/audio_legacy.py
def split_audio(
    audio_file: Path,
    method: str = "whisper",
    output_dir: Optional[Path] = None,
    model_size: str = "tiny",
    language: Optional[str] = None,
    min_silence_len: int = MIN_SILENCE_LENGTH,
    silence_thresh: int = SILENCE_DBFS_THRESHOLD,
    max_duration: int = MAX_DURATION,
) -> Path:
    """
    High-level function to split an audio file into chunks based on a chosen method.

    Args:
        audio_file (Path): The input audio file.
        method (str): Splitting method, "silence" or "whisper".
        output_dir (Path): Directory to store output.
        model_size (str): Whisper model size if method='whisper'.
        language (str): Language for whisper transcription if method='whisper'.
        min_silence_len (int): For silence-based detection, min silence length in ms.
        silence_thresh (int): Silence threshold in dBFS.
        max_duration (int): Max chunk length in seconds (converted to ms internally for silence detection).

    Returns:
        Path: Directory containing the resulting chunks.

    Example:
        >>> # Split using silence detection
        >>> split_audio(Path("my_audio.mp3"), method="silence")

        >>> # Split using whisper-based sentence boundaries
        >>> split_audio(Path("my_audio.mp3"), method="whisper", model_size="base", language="en")
    """

    logger.info(f"Splitting audio with max_duration={max_duration} seconds")

    if method == "whisper":
        boundaries, _ = detect_whisper_boundaries(
            audio_file, model_size=model_size, language=language
        )

    elif method == "silence":
        max_duration_ms = (
            max_duration * 1000
        )  # convert duration in seconds to milliseconds
        boundaries = detect_silence_boundaries(
            audio_file,
            min_silence_len=min_silence_len,
            silence_thresh=silence_thresh,
            max_duration=max_duration_ms,
        )
    else:
        raise ValueError(f"Unknown method: {method}. Must be 'silence' or 'whisper'.")

    # Note: split_audio_at_boundaries clears any existing files in output_dir
    # before exporting (useful when reprocessing).

    return split_audio_at_boundaries(
        audio_file, boundaries, output_dir=output_dir, max_duration=max_duration
    )
split_audio_at_boundaries(audio_file, boundaries, output_dir=None, max_duration=MAX_DURATION)

Split the audio file into chunks based on provided boundaries, ensuring all audio is included and boundaries align with the start of Whisper segments.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| audio_file | Path | The input audio file. | required |
| boundaries | List[Boundary] | Detected boundaries. | required |
| output_dir | Path | Directory to store the resulting chunks. | None |
| max_duration | int | Maximum chunk length in seconds. | MAX_DURATION |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| Path | Path | Directory containing the chunked audio files. |

Example:

    boundaries = [Boundary(34.02, 37.26, "..."), Boundary(38.0, 41.18, "...")]
    out_dir = split_audio_at_boundaries(Path("my_audio.mp3"), boundaries)

Source code in src/tnh_scholar/audio_processing/audio_legacy.py
def split_audio_at_boundaries(
    audio_file: Path,
    boundaries: List[Boundary],
    output_dir: Optional[Path] = None,
    max_duration: int = MAX_DURATION,
) -> Path:
    """
    Split the audio file into chunks based on provided boundaries, ensuring all audio is included
    and boundaries align with the start of Whisper segments.

    Args:
        audio_file (Path): The input audio file.
        boundaries (List[Boundary]): Detected boundaries.
        output_dir (Path): Directory to store the resulting chunks.
        max_duration (int): Maximum chunk length in seconds.

    Returns:
        Path: Directory containing the chunked audio files.

    Example:
        >>> boundaries = [Boundary(34.02, 37.26, "..."), Boundary(38.0, 41.18, "...")]
        >>> out_dir = split_audio_at_boundaries(Path("my_audio.mp3"), boundaries)
    """
    logger.info(f"Splitting audio with max_duration={max_duration} seconds")

    # Load the audio file
    audio = AudioSegment.from_file(audio_file)

    # Create output directory based on filename
    if output_dir is None:
        output_dir = audio_file.parent / f"{audio_file.stem}_chunks"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Clean up the output directory
    for file in output_dir.iterdir():
        if file.is_file():
            logger.info(f"Deleting existing file: {file}")
            file.unlink()

    chunk_start = 0  # Start time for the first chunk in ms
    chunk_count = 1
    current_chunk = AudioSegment.empty()

    for idx, boundary in enumerate(boundaries):
        segment_start_ms = int(boundary.start * 1000)
        if idx + 1 < len(boundaries):
            segment_end_ms = int(
                boundaries[idx + 1].start * 1000
            )  # Next boundary's start
        else:
            segment_end_ms = len(audio)  # End of the audio for the last boundary

        # Adjust for the first segment starting at 0
        if idx == 0 and segment_start_ms > 0:
            segment_start_ms = 0  # Ensure we include the very beginning of the audio

        segment = audio[segment_start_ms:segment_end_ms]

        logger.debug(
            f"Boundary index: {idx}, segment_start: {segment_start_ms / 1000}, segment_end: {segment_end_ms / 1000}, duration: {segment.duration_seconds}"
        )
        logger.debug(f"Current chunk Duration (s): {current_chunk.duration_seconds}")

        if len(current_chunk) + len(segment) <= max_duration * 1000:
            # Add segment to the current chunk
            current_chunk += segment
        else:
            # Export current chunk
            chunk_path = output_dir / f"chunk_{chunk_count}.mp3"
            current_chunk.export(chunk_path, format="mp3")
            logger.info(f"Exported: {chunk_path}")
            chunk_count += 1

            # Start a new chunk with the current segment
            current_chunk = segment

    # Export the final chunk if any audio remains
    if len(current_chunk) > 0:
        chunk_path = output_dir / f"chunk_{chunk_count}.mp3"
        current_chunk.export(chunk_path, format="mp3")
        logger.info(f"Exported: {chunk_path}")

    return output_dir
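
The chunking loop above greedily accumulates consecutive segments until the next one would push the open chunk past `max_duration`, then exports and starts fresh. A minimal sketch of that packing logic, using segment durations in seconds instead of AudioSegment objects (`pack_segments` is a hypothetical helper, not part of tnh_scholar):

```python
# Greedy chunk packing: group consecutive segment durations into chunks
# whose total length stays at or under max_duration, mirroring the
# accumulate/export loop in split_audio_at_boundaries.
from typing import List

def pack_segments(durations: List[float], max_duration: float) -> List[List[float]]:
    chunks: List[List[float]] = []
    current: List[float] = []
    for d in durations:
        if sum(current) + d <= max_duration:
            current.append(d)  # segment still fits in the open chunk
        else:
            if current:
                chunks.append(current)  # "export" the full chunk
            current = [d]  # start a new chunk with this segment
    if current:
        chunks.append(current)  # flush the final partial chunk
    return chunks

print(pack_segments([3.0, 4.0, 5.0, 2.0], max_duration=8.0))
```

Note that, like the original, a single segment longer than `max_duration` still becomes its own (oversize) chunk rather than being split further.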
whisper_model_transcribe(model, input_source, *args, **kwargs)

Wrapper around model.transcribe that suppresses the known 'FP16 is not supported on CPU; using FP32 instead' UserWarning and redirects unwanted 'OMP' messages to prevent interference.

This function accepts all args and kwargs that model.transcribe normally does, and supports input sources as file paths (str or Path) or in-memory audio arrays.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Any | The Whisper model instance. | required |
| input_source | Union[str, Path, ndarray] | Input audio file path, URL, or in-memory audio array. | required |
| *args | | Additional positional arguments for model.transcribe. | () |
| **kwargs | | Additional keyword arguments for model.transcribe. | {} |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, Any] | Transcription result from model.transcribe. |

Example:

    # Using a file path
    result = whisper_model_transcribe(my_model, "sample_audio.mp3", verbose=True)

    # Using an audio array
    result = whisper_model_transcribe(my_model, audio_array, language="en")

Source code in src/tnh_scholar/audio_processing/audio_legacy.py
def whisper_model_transcribe(
    model: Any,
    input_source: Any,
    *args,
    **kwargs,
) -> Dict[str, Any]:
    """
    Wrapper around model.transcribe that suppresses the known
    'FP16 is not supported on CPU; using FP32 instead' UserWarning
    and redirects unwanted 'OMP' messages to prevent interference.

    This function accepts all args and kwargs that model.transcribe normally does,
    and supports input sources as file paths (str or Path) or in-memory audio arrays.

    Parameters:
        model (Any): The Whisper model instance.
        input_source (Union[str, Path, np.ndarray]): Input audio file path, URL, or in-memory audio array.
        *args: Additional positional arguments for model.transcribe.
        **kwargs: Additional keyword arguments for model.transcribe.

    Returns:
        Dict[str, Any]: Transcription result from model.transcribe.

    Example:
        # Using a file path
        result = whisper_model_transcribe(my_model, "sample_audio.mp3", verbose=True)

        # Using an audio array
        result = whisper_model_transcribe(my_model, audio_array, language="en")
    """

    # class StdoutFilter(io.StringIO):
    #     def __init__(self, original_stdout):
    #         super().__init__()
    #         self.original_stdout = original_stdout

    #     def write(self, message):
    #         # Suppress specific messages like 'OMP:' while allowing others
    #         if "OMP:" not in message:
    #             self.original_stdout.write(message)

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="FP16 is not supported on CPU; using FP32 instead",
            category=UserWarning,
        )

        # Redirect stdout to suppress OMP messages
        # original_stdout = sys.stdout
        # sys.stdout = filtered_stdout

        try:
            # Convert Path to str if needed
            if isinstance(input_source, Path):
                input_source = str(input_source)

            # Call the original transcribe function
            return model.transcribe(input_source, *args, **kwargs)
        finally:
            # Restore original stdout
            # sys.stdout = original_stdout
            pass
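
The warning-suppression pattern used above can be demonstrated with the standard library alone: inside `warnings.catch_warnings()`, an `ignore` filter keyed to the exact message swallows that one `UserWarning` while leaving other warnings unaffected. A small self-contained sketch (the `noisy` function is illustrative, not part of the package):

```python
# Suppress one specific UserWarning by message, as whisper_model_transcribe
# does for the "FP16 is not supported on CPU" warning.
import warnings

def noisy() -> str:
    warnings.warn("FP16 is not supported on CPU; using FP32 instead", UserWarning)
    return "ok"

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")   # record everything by default...
    warnings.filterwarnings(          # ...except this exact message
        "ignore",
        message="FP16 is not supported on CPU; using FP32 instead",
        category=UserWarning,
    )
    result = noisy()

print(result, len(caught))  # the targeted warning is swallowed
```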

diarization

__all__ = ['DiarizationProcessor', 'diarize', 'diarize_to_file', 'DiarizationParams', 'PyannoteClient', 'PyannoteConfig'] module-attribute
DiarizationParams

Bases: BaseModel

Per-request diarization options; maps to pyannote API payload. Use .to_api_dict() to emit API field names.

Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
class DiarizationParams(BaseModel):
    """
    Per-request diarization options; maps to pyannote API payload.
    Use .to_api_dict() to emit API field names.
    """

    model_config = ConfigDict(
        frozen=True,            # make instances immutable
        populate_by_name=True,  # allow using pythonic field names with aliases
        extra="forbid",         # catch accidental fields at construction
    )

    # Pythonic attribute -> API alias on dump
    num_speakers: int | Literal["auto"] | None = Field(
        default=None,
        alias="numSpeakers",
        description="Fixed number of speakers or 'auto' for detection.",
    )
    confidence: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Confidence threshold for segments.",
    )
    webhook: AnyUrl | None = Field(
        default=None,
        description="Webhook URL for job status callbacks.",
    )

    def to_api_dict(self) -> dict[str, Any]:
        """Return payload dict using API field names (camelCase) and excluding Nones."""
        return self.model_dump(by_alias=True, exclude_none=True)
confidence = Field(default=None, ge=0.0, le=1.0, description='Confidence threshold for segments.') class-attribute instance-attribute
model_config = ConfigDict(frozen=True, populate_by_name=True, extra='forbid') class-attribute instance-attribute
num_speakers = Field(default=None, alias='numSpeakers', description="Fixed number of speakers or 'auto' for detection.") class-attribute instance-attribute
webhook = Field(default=None, description='Webhook URL for job status callbacks.') class-attribute instance-attribute
to_api_dict()

Return payload dict using API field names (camelCase) and excluding Nones.

Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
def to_api_dict(self) -> dict[str, Any]:
    """Return payload dict using API field names (camelCase) and excluding Nones."""
    return self.model_dump(by_alias=True, exclude_none=True)
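
The effect of `to_api_dict()` — emit camelCase API field names and drop unset values — can be mimicked without pydantic. A stdlib-only stand-in (the `ALIASES` table and free function here are illustrative, not the real BaseModel machinery):

```python
# Stand-in for DiarizationParams.to_api_dict(): rename pythonic keys to
# their camelCase API aliases and exclude None (unset) values. The real
# class gets this from pydantic's model_dump(by_alias=True, exclude_none=True).
from typing import Any, Dict, Optional

ALIASES = {"num_speakers": "numSpeakers"}  # other fields keep their names

def to_api_dict(num_speakers: Optional[Any] = None,
                confidence: Optional[float] = None,
                webhook: Optional[str] = None) -> Dict[str, Any]:
    fields = {"num_speakers": num_speakers, "confidence": confidence, "webhook": webhook}
    return {ALIASES.get(k, k): v for k, v in fields.items() if v is not None}

print(to_api_dict(num_speakers="auto", confidence=0.5))
# → {'numSpeakers': 'auto', 'confidence': 0.5}
```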
DiarizationProcessor

Orchestrator over a DiarizationService.

This layer delegates to the service for generation and handles persistence.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
class DiarizationProcessor:
    """Orchestrator over a DiarizationService.

    This layer delegates to the service for generation and handles persistence.
    """

    def __init__(
        self,
        audio_file_path: Path,
        output_path: Optional[Path] = None,
        *,
        service: Optional[DiarizationService] = None,
        params: Optional[DiarizationParams] = None,
        api_key: Optional[str] = None,
        writer: Optional[ResultWriter] = None,
    ) -> None:
        self.audio_file_path: Path = audio_file_path.resolve()
        if not self.audio_file_path.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_file_path}")

        # Default output path
        self.output_path: Path = (
            output_path.resolve()
            if output_path is not None
            else self.audio_file_path.parent / f"{self.audio_file_path.stem}{PYANNOTE_FILE_STR}.json"
        )

        # Service & config
        # If a concrete service is not provided, default to PyannoteService.
        # Only pass api_key to PyannoteClient if it is not None.
        default_client = PyannoteClient(api_key) if api_key is not None else PyannoteClient()
        self.service: DiarizationService = service or PyannoteService(default_client)
        self.params: Optional[DiarizationParams] = params
        self.writer: ResultWriter = writer or FileResultWriter()

        # Cached state
        self._last_response: Optional[DiarizationResponse] = None
        self._last_job_id: Optional[str] = None

    # ---- Two-phase job control (nice for UIs) --------------------------------

    def start(self) -> JobHandle:
        """Start a job and cache its job_id."""
        job_id = self.service.start(self.audio_file_path, params=self.params)
        if not job_id:
            raise RuntimeError("Diarization service returned empty job_id")
        self._last_job_id = job_id
        return JobHandle(job_id=job_id)

    def get_response(
        self, job: Optional[Union[JobHandle, str]] = None, *, wait_until_complete: bool = False
        ) -> DiarizationResponse:
        """Fetch current/final response for a job, caching the last response."""
        target_id: Optional[str]
        if isinstance(job, JobHandle):
            target_id = job.job_id
        else:
            target_id = job or self._last_job_id
        if target_id is None:
            raise ValueError(
                "No job_id provided and no previous job has been started. Call start() or pass a job_id."
            )
        resp = self.service.get_response(target_id, wait_until_complete=wait_until_complete)
        self._last_response = resp
        return resp

    # ---- One-shot path --------------------------------------------------------

    def generate(self, *, wait_until_complete: bool = True) -> DiarizationResponse:
        """One-shot convenience: delegate to the service and cache the response."""
        resp = self.service.generate(
            self.audio_file_path, 
            params=self.params, 
            wait_until_complete=wait_until_complete
            )
        self._last_response = resp
        # If the service exposes a job_id in the envelope, cache it for UIs
        # Do not fail on metadata issues; response is primary.
        try:
            job_id = getattr(resp, "job_id", None)
            if isinstance(job_id, str):
                self._last_job_id = job_id
        except (AttributeError, TypeError) as e:
            logger.warning(f"Could not extract job_id from response: {e}")
        return resp

    # ---- Persistence ----------------------------------------------------------

    def export(self, response: Optional[DiarizationResponse] = None) -> Path:
        """Write the provided or last response to `self.output_path`."""
        result = response or self._last_response
        if result is None:
            raise ValueError(
                "No DiarizationResponse available; call generate()/get_response() first or pass response="
                )
        return self.writer.write(self.output_path, result)
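
The two-phase job control DiarizationProcessor exposes — `start()` returns a handle, `get_response()` fetches by id and caches the last response — can be sketched against a stub service. `FakeService` and the trimmed-down `Processor` below are illustrative stand-ins, not the package's DiarizationService or real processor:

```python
# Two-phase job-control pattern: start a job, then fetch its response by id,
# caching the last job id so get_response() works without arguments.
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class JobHandle:
    job_id: str

class FakeService:
    def __init__(self) -> None:
        self._jobs: Dict[str, str] = {}

    def start(self, audio: str) -> str:
        job_id = f"job-{len(self._jobs) + 1}"
        self._jobs[job_id] = f"diarized:{audio}"
        return job_id

    def get_response(self, job_id: str) -> str:
        return self._jobs[job_id]

class Processor:
    def __init__(self, service: FakeService, audio: str) -> None:
        self.service, self.audio = service, audio
        self._last_job_id: Optional[str] = None
        self._last_response: Optional[str] = None

    def start(self) -> JobHandle:
        self._last_job_id = self.service.start(self.audio)
        return JobHandle(self._last_job_id)

    def get_response(self, job: Optional[JobHandle] = None) -> str:
        target = job.job_id if job else self._last_job_id
        if target is None:
            raise ValueError("call start() first or pass a job handle")
        self._last_response = self.service.get_response(target)
        return self._last_response

proc = Processor(FakeService(), "talk.mp3")
handle = proc.start()
print(handle.job_id, proc.get_response())
```

The same caching shape is why the real processor's `export()` can fall back to the last response when none is passed explicitly.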
audio_file_path = audio_file_path.resolve() instance-attribute
output_path = output_path.resolve() if output_path is not None else self.audio_file_path.parent / f'{self.audio_file_path.stem}{PYANNOTE_FILE_STR}.json' instance-attribute
params = params instance-attribute
service = service or PyannoteService(default_client) instance-attribute
writer = writer or FileResultWriter() instance-attribute
__init__(audio_file_path, output_path=None, *, service=None, params=None, api_key=None, writer=None)
Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def __init__(
    self,
    audio_file_path: Path,
    output_path: Optional[Path] = None,
    *,
    service: Optional[DiarizationService] = None,
    params: Optional[DiarizationParams] = None,
    api_key: Optional[str] = None,
    writer: Optional[ResultWriter] = None,
) -> None:
    self.audio_file_path: Path = audio_file_path.resolve()
    if not self.audio_file_path.exists():
        raise FileNotFoundError(f"Audio file not found: {audio_file_path}")

    # Default output path
    self.output_path: Path = (
        output_path.resolve()
        if output_path is not None
        else self.audio_file_path.parent / f"{self.audio_file_path.stem}{PYANNOTE_FILE_STR}.json"
    )

    # Service & config
    # If a concrete service is not provided, default to PyannoteService.
    # Only pass api_key to PyannoteClient if it is not None.
    default_client = PyannoteClient(api_key) if api_key is not None else PyannoteClient()
    self.service: DiarizationService = service or PyannoteService(default_client)
    self.params: Optional[DiarizationParams] = params
    self.writer: ResultWriter = writer or FileResultWriter()

    # Cached state
    self._last_response: Optional[DiarizationResponse] = None
    self._last_job_id: Optional[str] = None
export(response=None)

Write the provided or last response to self.output_path.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def export(self, response: Optional[DiarizationResponse] = None) -> Path:
    """Write the provided or last response to `self.output_path`."""
    result = response or self._last_response
    if result is None:
        raise ValueError(
            "No DiarizationResponse available; call generate()/get_response() first or pass response="
            )
    return self.writer.write(self.output_path, result)
generate(*, wait_until_complete=True)

One-shot convenience: delegate to the service and cache the response.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def generate(self, *, wait_until_complete: bool = True) -> DiarizationResponse:
    """One-shot convenience: delegate to the service and cache the response."""
    resp = self.service.generate(
        self.audio_file_path, 
        params=self.params, 
        wait_until_complete=wait_until_complete
        )
    self._last_response = resp
    # If the service exposes a job_id in the envelope, cache it for UIs
    # Do not fail on metadata issues; response is primary.
    try:
        job_id = getattr(resp, "job_id", None)
        if isinstance(job_id, str):
            self._last_job_id = job_id
    except (AttributeError, TypeError) as e:
        logger.warning(f"Could not extract job_id from response: {e}")
    return resp
get_response(job=None, *, wait_until_complete=False)

Fetch current/final response for a job, caching the last response.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def get_response(
    self, job: Optional[Union[JobHandle, str]] = None, *, wait_until_complete: bool = False
    ) -> DiarizationResponse:
    """Fetch current/final response for a job, caching the last response."""
    target_id: Optional[str]
    if isinstance(job, JobHandle):
        target_id = job.job_id
    else:
        target_id = job or self._last_job_id
    if target_id is None:
        raise ValueError(
            "No job_id provided and no previous job has been started. Call start() or pass a job_id."
        )
    resp = self.service.get_response(target_id, wait_until_complete=wait_until_complete)
    self._last_response = resp
    return resp
start()

Start a job and cache its job_id.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def start(self) -> JobHandle:
    """Start a job and cache its job_id."""
    job_id = self.service.start(self.audio_file_path, params=self.params)
    if not job_id:
        raise RuntimeError("Diarization service returned empty job_id")
    self._last_job_id = job_id
    return JobHandle(job_id=job_id)
PyannoteClient

Client for interacting with the pyannote.ai speaker diarization API.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_client.py
class PyannoteClient:
    """Client for interacting with the pyannote.ai speaker diarization API."""

    def __init__(self, api_key: Optional[str] = None, config: Optional[PyannoteConfig] = None):
        """
        Initialize with API key.

        Args:
            api_key: Pyannote.ai API key (defaults to environment variable)
        """
        self.api_key = api_key or os.getenv("PYANNOTEAI_API_TOKEN")
        if not self.api_key:
            raise APIKeyError(
                "API key is required. Set PYANNOTEAI_API_TOKEN environment "
                "variable or pass as parameter"
            )

        self.config = config or PyannoteConfig()
        self.polling_config = self.config.polling_config

        # Upload-specific timeouts (longer than general calls)
        self.upload_timeout = self.config.upload_timeout
        self.upload_max_retries = self.config.upload_max_retries
        self.network_timeout = self.config.network_timeout

        self.headers = {"Authorization": f"Bearer {self.api_key}"}

    # -----------------------
    # Upload helpers
    # -----------------------
    def _create_media_id(self) -> str:
        """Generate a unique media ID."""
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
        return f"{self.config.media_prefix}{timestamp}"

    def _upload_file(self, file_path: Path, upload_url: str) -> bool:
        """
        Upload file to the provided URL.

        Args:
            file_path: Path to the file to upload
            upload_url: URL to upload to

        Returns:
            bool: True if upload successful, False otherwise
        """
        try:
            logger.info(f"Uploading file to Pyannote.ai: {file_path}")
            with open(file_path, "rb") as file_data:
                upload_response = requests.put(
                    upload_url,
                    data=file_data,
                    headers={"Content-Type": self.config.media_content_type},
                    timeout=self.upload_timeout,
                )

            upload_response.raise_for_status()
            logger.info("File uploaded successfully")
            return True

        except requests.RequestException as e:
            logger.error(f"Failed to upload file: {e}")
            return False

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential_jitter(exp_base=2, initial=3, max=30),
        retry=retry_if_exception_type(
            (requests.RequestException, requests.Timeout, requests.ConnectionError)
            ),
    )
    def upload_audio(self, file_path: Path) -> Optional[str]:
        """
        Upload audio file with retry logic for network robustness.

        Retries on network errors with exponential backoff.
        Fails fast on permanent errors (auth, file not found, etc.).
        """
        try:
            if not file_path.exists() or not file_path.is_file():
                logger.error(f"Audio file not found or is not a file: {file_path}")
                return None
        except OSError as e:
            logger.error(f"Error accessing audio file '{file_path}': {e}")
            return None

        try:
            file_size_mb = file_path.stat().st_size / (1024 * 1024)
        except OSError as e:
            logger.error(f"Error reading file size for '{file_path}': {e}")
            return None

        logger.info(f"Starting upload of {file_path.name} ({file_size_mb:.1f}MB)")

        try:
            # Create media ID
            media_id = self._create_media_id()
            logger.debug(f"Created media ID: {media_id}")

            # Get upload URL (this is fast, use normal timeout)
            upload_url = self._data_upload_url(media_id)
            if not upload_url:
                return None

            # Upload file (this is slow, use extended timeout)
            if self._upload_file(file_path, upload_url):
                logger.info(f"Upload completed successfully: {media_id}")
                return media_id
            else:
                logger.error(f"Upload failed for {file_path.name}")
                return None

        except Exception as e:
            # Log but don't retry - let tenacity handle retries
            logger.error(f"Upload attempt failed: {e}")
            raise  # Re-raise for tenacity to handle

    def _data_upload_url(self, media_id: str) -> Optional[str]:
        response = requests.post(
            self.config.media_input_endpoint,
            headers=self.headers,
            json={"url": media_id},
            timeout=self.network_timeout,
        )
        upload_url = self._extract_response_info(
            response, "url", "No upload URL in API response"
        )
        logger.debug(f"Got upload URL for media ID: {media_id}")
        return upload_url

    def _extract_response_info(self, response, response_type, error_msg):
        response.raise_for_status()
        info = response.json()
        if result := info.get(response_type):
            return result
        else:
            raise ValueError(error_msg)

    # -----------------------
    # Start job
    # -----------------------
    def start_diarization(self, media_id: str, params: Optional[DiarizationParams] = None) -> Optional[str]:
        """
        Start diarization job with pyannote.ai API.

        Args:
            media_id: The media ID from upload_audio
            params: Optional parameters for diarization

        Returns:
            Optional[str]: The job ID if started successfully, None otherwise
        """
        try:
            return self._send_payload(media_id, params)
        except requests.RequestException as e:
            logger.error(f"API request failed: {e}")
            return None
        except ValueError as e:
            logger.error(f"Invalid API response: {e}")
            return None

    def _send_payload(self, media_id, params):
        payload: Dict[str, Any] = {"url": media_id}
        if params:
            payload |= params.to_api_dict()
            logger.info(f"Starting diarization with params: {params}")
        logger.debug(f"Full payload: {payload}")

        response = requests.post(self.config.diarize_endpoint, headers=self.headers, json=payload)
        job_id = self._extract_response_info(
            response, JOB_ID_FIELD, "API response missing job ID"
        )
        logger.info(f"Diarization job {job_id} started successfully")
        return job_id

    # -----------------------
    # Status / Polling
    # -----------------------
    def check_job_status(self, job_id: str) -> Optional[JobStatusResponse]:
        """
        Check the status of a diarization job.

        Returns a typed transport model (JobStatusResponse) or None on failure.
        """
        return self._check_status_with_retry(job_id)

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential_jitter(exp_base=2, initial=1, max=10),
        retry=retry_if_exception_type(
            (requests.RequestException, requests.Timeout, requests.ConnectionError)
            ),
    )
    def _check_status_with_retry(self, job_id: str) -> Optional[JobStatusResponse]:
        """
        Check job status with network error retry logic.

        Retries network failures without killing the polling loop.
        Fails fast on API errors (auth, malformed response, etc.).

        Used as the status function in the JobPoller helper class.
        """
        try:
            endpoint = f"{self.config.job_status_endpoint}/{job_id}"
            response = requests.get(endpoint, headers=self.headers)
            response.raise_for_status()
            result = response.json()

            try:
                jsr = JobStatusResponse.model_validate(result)
            except Exception as ve:
                logger.error(f"Invalid status response for job {job_id}: {result} ({ve})")
                return None

            return jsr

        except requests.RequestException as e:
            logger.warning(f"Status check network error for job {job_id}: {e}")
            raise  # Let tenacity retry
        except Exception as e:
            logger.error(f"Unexpected status check error for job {job_id}: {e}")
            return None  # Don't retry on unexpected errors

    class JobPoller:
        """
        Generic job polling helper for long-running async jobs.
        """

        def __init__(self, status_fn, job_id: str, polling_config: PollingConfig):
            self.status_fn = status_fn
            self.job_id = job_id
            self.polling_config = polling_config
            self.poll_count = 0
            self.start_time = time.time()
            self.last_status: Optional[JobStatusResponse] = None
            self._last_error_reason: Optional[str] = None

        def _poll(self) -> JobStatusResponse | _PollSignal | None:
            self.poll_count += 1
            try:
                status_response = self.status_fn(self.job_id)
            except RetryError as e:
                self._last_error_reason = f"status check retry exhausted: {e}"
                logger.error(f"Status check retries exhausted for job {self.job_id}: {e}")
                return _PollSignal.STATUS_RETRY_EXHAUSTED

            if status_response is None:
                logger.error(f"Failed to get status for job {self.job_id} after retries")
                self._last_error_reason = "status response None"
                return None

            # track last known status for timeout / errors
            self.last_status = status_response

            status = status_response.status
            elapsed = time.time() - self.start_time

            if status == JobStatus.SUCCEEDED:
                logger.info(
                    f"Job {self.job_id} completed successfully after {elapsed:.1f}s ({self.poll_count} polls)"
                )
                return status_response

            if status == JobStatus.FAILED:
                logger.error(f"Job {self.job_id} failed: {status_response.server_error_msg}")
                return status_response

            # Job still running - signal the retry loop to continue polling
            logger.info(f"Job {self.job_id} status: {status} (elapsed: {elapsed:.1f}s)")
            return _PollSignal.CONTINUE

        # --- Internal builders to attach polling context and craft JSRs ---
        def _attach_context(
            self, 
            base: Optional[JobStatusResponse], 
            *, 
            outcome: PollOutcome, 
            elapsed: float, 
            msg: Optional[str] = None
            ) -> JobStatusResponse:
            """Return a JSR carrying outcome + poll context. If `base` exists, preserve its
            status/payload/server_error_msg unless `msg` overrides it. Otherwise, synthesize a minimal JSR."""
            if base is None:
                return JobStatusResponse(
                    job_id=self.job_id,
                    outcome=outcome,
                    status=None,
                    server_error_msg=msg,
                    payload=None,
                    polls=self.poll_count,
                    elapsed_s=elapsed,
                )
            return JobStatusResponse(
                job_id=self.job_id,
                outcome=outcome,
                status=base.status,
                server_error_msg=msg if msg is not None else base.server_error_msg,
                payload=base.payload,
                polls=self.poll_count,
                elapsed_s=elapsed,
            )

        def _on_terminal(self, jsr: JobStatusResponse, *, elapsed: float) -> JobStatusResponse:
            """Attach poll context to a terminal server response (SUCCEEDED/FAILED)."""
            return JobStatusResponse(
                job_id=self.job_id,
                outcome=PollOutcome.SUCCEEDED if jsr.status == JobStatus.SUCCEEDED else PollOutcome.FAILED,
                status=jsr.status,
                server_error_msg=jsr.server_error_msg,
                payload=jsr.payload,
                polls=self.poll_count,
                elapsed_s=elapsed,
            )

        def _on_status_retry_exhausted(self, *, elapsed: float) -> JobStatusResponse:
            return self._attach_context(
                self.last_status, 
                outcome=PollOutcome.NETWORK_ERROR, 
                elapsed=elapsed, 
                msg=self._last_error_reason
                )

        def _on_invalid_payload(self, *, elapsed: float) -> JobStatusResponse:
            return self._attach_context(
                self.last_status, 
                outcome=PollOutcome.ERROR, 
                elapsed=elapsed, 
                msg="invalid status payload"
                )

        def _on_timeout(self, err: RetryError, *, elapsed: float) -> JobStatusResponse:
            return self._attach_context(
                self.last_status, 
                outcome=PollOutcome.TIMEOUT, 
                elapsed=elapsed, 
                msg=str(err)
                )

        def _on_interrupt(self, *, elapsed: float) -> JobStatusResponse:
            return self._attach_context(
                self.last_status, 
                outcome=PollOutcome.INTERRUPTED, 
                elapsed=elapsed, 
                msg="KeyboardInterrupt"
                )

        def _on_exception(self, err: Exception, *, elapsed: float) -> JobStatusResponse:
            return self._attach_context(
                self.last_status, 
                outcome=PollOutcome.ERROR, 
                elapsed=elapsed, 
                msg=str(err)
                )

        def run(self) -> JobStatusResponse:
            try:
                result = self._setup_and_run_poll()
                elapsed = time.time() - self.start_time

                if isinstance(result, JobStatusResponse):
                    # Terminal SUCCEEDED/FAILED (or unexpected non-terminal delivered): attach context
                    return self._on_terminal(result, elapsed=elapsed)

                if result is _PollSignal.STATUS_RETRY_EXHAUSTED:
                    return self._on_status_retry_exhausted(elapsed=elapsed)

                # None indicates invalid status payload or unexpected branch
                return self._on_invalid_payload(elapsed=elapsed)

            except RetryError as e:
                # Outer polling timeout
                elapsed = time.time() - self.start_time
                logger.info(f"Polling timed out for job {self.job_id} after {elapsed:.1f}s")
                return self._on_timeout(e, elapsed=elapsed)
            except KeyboardInterrupt:
                elapsed = time.time() - self.start_time
                logger.info(f"Polling for job {self.job_id} interrupted by user. Exiting.")
                return self._on_interrupt(elapsed=elapsed)
            except Exception as e:
                elapsed = time.time() - self.start_time
                logger.error(f"Polling failed for job {self.job_id}: {e}")
                return self._on_exception(e, elapsed=elapsed)

        def _setup_and_run_poll(self) -> Optional[JobStatusResponse | _PollSignal]:
            cfg = self.polling_config
            stop_policy = stop_never if cfg.polling_timeout is None else stop_after_delay(cfg.polling_timeout)
            retrying = Retrying(
                retry=retry_if_result(lambda result: result is _PollSignal.CONTINUE),
                stop=stop_policy,
                wait=wait_exponential_jitter(
                    exp_base=cfg.exp_base,
                    initial=cfg.initial_poll_time,
                    max=cfg.max_interval,
                ),
                reraise=True,
            )
            result = retrying(self._poll)
            if isinstance(result, JobStatusResponse):
                return result
            # could be STATUS_RETRY_EXHAUSTED sentinel or None
            logger.info(f"Polling ended with result: {result}")
            return result

    def poll_job_until_complete(
        self,
        job_id: str,
        estimated_duration: Optional[float] = None,
        timeout: Optional[float] = None,
        wait_until_complete: Optional[bool] = False,
    ) -> JobStatusResponse:
        """
        Poll until the job reaches a terminal state or a client-side stop condition, and
        return a unified JobStatusResponse (JSR) that includes both the server payload
        and polling context via `outcome`, `polls`, and `elapsed_s`.

        Args:
            job_id: Remote job identifier to poll.
            estimated_duration: Optional hint; currently unused (reserved for adaptive backoff).
            timeout: Optional hard timeout in seconds for this poll call. If provided, it overrides
                     the client's default polling timeout. Ignored if `wait_until_complete` is True.
            wait_until_complete: If True, ignore timeout and poll indefinitely (subject to process lifetime).

        Returns:
            JobStatusResponse: unified transport + polling-context result.
        """
        if timeout is not None and wait_until_complete:
            raise ConfigurationError("Timeout cannot be set with wait_until_complete")

        # Derive an effective timeout for this call, without mutating client defaults
        effective_timeout = None if wait_until_complete else (
            timeout if timeout is not None else self.polling_config.polling_timeout
            )

        cfg = PollingConfig(
            polling_timeout=effective_timeout,
            initial_poll_time=self.polling_config.initial_poll_time,
            exp_base=self.polling_config.exp_base,
            max_interval=self.polling_config.max_interval,
        )

        poller = self.JobPoller(
            status_fn=self._check_status_with_retry,
            job_id=job_id,
            polling_config=cfg,
        )
        return poller.run()
api_key = api_key or os.getenv('PYANNOTEAI_API_TOKEN') instance-attribute
config = config or PyannoteConfig() instance-attribute
headers = {'Authorization': f'Bearer {self.api_key}'} instance-attribute
network_timeout = self.config.network_timeout instance-attribute
polling_config = self.config.polling_config instance-attribute
upload_max_retries = self.config.upload_max_retries instance-attribute
upload_timeout = self.config.upload_timeout instance-attribute
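The `_send_payload` helper above builds the request body by merging optional diarization parameters into a base payload with the dict union operator. A minimal sketch of that merge, using a hypothetical stand-in for `DiarizationParams.to_api_dict()` (the field name is illustrative, not the real API schema):

```python
from typing import Any, Dict, Optional


class FakeParams:
    """Hypothetical stand-in for DiarizationParams; the field name is illustrative."""

    def to_api_dict(self) -> Dict[str, Any]:
        return {"numSpeakers": 2}


def build_payload(media_id: str, params: Optional[FakeParams] = None) -> Dict[str, Any]:
    # Base payload always carries the media URL returned by upload_audio
    payload: Dict[str, Any] = {"url": media_id}
    if params:
        # Dict union-assignment (Python 3.9+) merges optional parameters in place
        payload |= params.to_api_dict()
    return payload


print(build_payload("media://diarization-abc"))
print(build_payload("media://diarization-abc", FakeParams()))
```

The union-assignment keeps the base `url` key and overlays only the keys the caller actually set, which is why an empty `params` leaves the payload untouched.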
JobPoller

Generic job polling helper for long-running async jobs.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_client.py
job_id = job_id instance-attribute
last_status = None instance-attribute
poll_count = 0 instance-attribute
polling_config = polling_config instance-attribute
start_time = time.time() instance-attribute
status_fn = status_fn instance-attribute
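The decision logic in `_poll` above reduces to three cases: a terminal status is returned as-is, a missing status yields None, and a still-running job signals the retry loop to continue. A simplified, self-contained sketch of that state machine driven by a scripted status function (the enums here are stand-ins for the real `JobStatus` and `_PollSignal` types):

```python
import enum
from typing import Callable, List, Optional


class Status(enum.Enum):   # stand-in for JobStatus
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"


class Signal(enum.Enum):   # stand-in for _PollSignal
    CONTINUE = "continue"


def poll_once(status_fn: Callable[[], Optional[Status]]):
    """Mirror of the _poll branch structure: terminal -> status, running -> CONTINUE."""
    status = status_fn()
    if status is None:
        return None                  # invalid / missing status payload
    if status in (Status.SUCCEEDED, Status.FAILED):
        return status                # terminal: hand back to run()
    return Signal.CONTINUE           # still running: retry loop keeps polling


def run_to_completion(statuses: List[Optional[Status]]):
    """Drive poll_once over a scripted status sequence until a non-CONTINUE result."""
    it = iter(statuses)
    while True:
        result = poll_once(lambda: next(it))
        if result is not Signal.CONTINUE:
            return result


print(run_to_completion([Status.RUNNING, Status.RUNNING, Status.SUCCEEDED]))
```

In the real class, the `while` loop is replaced by tenacity's `Retrying` with `retry_if_result`, which also supplies the exponential-backoff wait between polls.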
__init__(status_fn, job_id, polling_config)
run()
__init__(api_key=None, config=None)

Initialize with API key.

Parameters:

    api_key (Optional[str], default None):
        Pyannote.ai API key (defaults to the PYANNOTEAI_API_TOKEN environment variable).
Source code in src/tnh_scholar/audio_processing/diarization/pyannote_client.py
def __init__(self, api_key: Optional[str] = None, config: Optional[PyannoteConfig] = None):
    """
    Initialize with API key.

    Args:
        api_key: Pyannote.ai API key (defaults to environment variable)
    """
    self.api_key = api_key or os.getenv("PYANNOTEAI_API_TOKEN")
    if not self.api_key:
        raise APIKeyError(
            "API key is required. Set PYANNOTEAI_API_TOKEN environment "
            "variable or pass as parameter"
        )

    self.config = config or PyannoteConfig()
    self.polling_config = self.config.polling_config

    # Upload-specific timeouts (longer than general calls)
    self.upload_timeout = self.config.upload_timeout
    self.upload_max_retries = self.config.upload_max_retries
    self.network_timeout = self.config.network_timeout

    self.headers = {"Authorization": f"Bearer {self.api_key}"}
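The key-resolution pattern in `__init__` (explicit argument first, then environment variable, then a hard failure) can be isolated as a small pure function; in this sketch `ValueError` stands in for the package's `APIKeyError`:

```python
import os
from typing import Optional


def resolve_api_key(api_key: Optional[str] = None) -> str:
    """Explicit argument wins; otherwise fall back to the environment."""
    key = api_key or os.getenv("PYANNOTEAI_API_TOKEN")
    if not key:
        # Mirrors APIKeyError in the real client
        raise ValueError(
            "API key is required. Set PYANNOTEAI_API_TOKEN environment "
            "variable or pass as parameter"
        )
    return key


os.environ["PYANNOTEAI_API_TOKEN"] = "env-token"
print(resolve_api_key())            # falls back to the environment
print(resolve_api_key("explicit"))  # explicit argument takes precedence
```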
check_job_status(job_id)

Check the status of a diarization job.

Returns a typed transport model (JobStatusResponse) or None on failure.

poll_job_until_complete(job_id, estimated_duration=None, timeout=None, wait_until_complete=False)

Poll until the job reaches a terminal state or a client-side stop condition, and return a unified JobStatusResponse (JSR) that includes both the server payload and polling context via outcome, polls, and elapsed_s.

Parameters:

    job_id (str, required):
        Remote job identifier to poll.
    estimated_duration (Optional[float], default None):
        Optional hint; currently unused (reserved for adaptive backoff).
    timeout (Optional[float], default None):
        Hard timeout in seconds for this poll call. If provided, it overrides the client's default polling timeout. Ignored if wait_until_complete is True.
    wait_until_complete (Optional[bool], default False):
        If True, ignore timeout and poll indefinitely (subject to process lifetime).

Returns:

    JobStatusResponse:
        Unified transport + polling-context result.
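The timeout precedence inside `poll_job_until_complete` (wait_until_complete beats everything, an explicit timeout beats the client default, and combining the two is a configuration error) can be isolated as a small pure function; `ValueError` here stands in for the package's `ConfigurationError`:

```python
from typing import Optional


def effective_timeout(
    timeout: Optional[float],
    wait_until_complete: bool,
    default_timeout: Optional[float],
) -> Optional[float]:
    """Replicate the timeout precedence used by poll_job_until_complete."""
    if timeout is not None and wait_until_complete:
        raise ValueError("Timeout cannot be set with wait_until_complete")
    if wait_until_complete:
        return None                  # None means: poll indefinitely
    return timeout if timeout is not None else default_timeout


print(effective_timeout(None, True, 600.0))   # indefinite polling
print(effective_timeout(30.0, False, 600.0))  # explicit timeout wins
print(effective_timeout(None, False, 600.0))  # falls back to client default
```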

start_diarization(media_id, params=None)

Start diarization job with pyannote.ai API.

Parameters:

    media_id (str, required):
        The media ID from upload_audio.
    params (Optional[DiarizationParams], default None):
        Optional parameters for diarization.

Returns:

    Optional[str]:
        The job ID if started successfully, None otherwise.

upload_audio(file_path)

Upload audio file with retry logic for network robustness.

Retries on network errors with exponential backoff. Fails fast on permanent errors (auth, file not found, etc.).

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_client.py
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential_jitter(exp_base=2, initial=3, max=30),
    retry=retry_if_exception_type(
        (requests.RequestException, requests.Timeout, requests.ConnectionError)
        ),
)
def upload_audio(self, file_path: Path) -> Optional[str]:
    """
    Upload audio file with retry logic for network robustness.

    Retries on network errors with exponential backoff.
    Fails fast on permanent errors (auth, file not found, etc.).
    """
    try:
        if not file_path.exists() or not file_path.is_file():
            logger.error(f"Audio file not found or is not a file: {file_path}")
            return None
    except OSError as e:
        logger.error(f"Error accessing audio file '{file_path}': {e}")
        return None

    try:
        file_size_mb = file_path.stat().st_size / (1024 * 1024)
    except OSError as e:
        logger.error(f"Error reading file size for '{file_path}': {e}")
        return None

    logger.info(f"Starting upload of {file_path.name} ({file_size_mb:.1f}MB)")

    try:
        # Create media ID
        media_id = self._create_media_id()
        logger.debug(f"Created media ID: {media_id}")

        # Get upload URL (this is fast, use normal timeout)
        upload_url = self._data_upload_url(media_id)
        if not upload_url:
            return None

        # Upload file (this is slow, use extended timeout)
        if self._upload_file(file_path, upload_url):
            logger.info(f"Upload completed successfully: {media_id}")
            return media_id
        else:
            logger.error(f"Upload failed for {file_path.name}")
            return None

    except Exception as e:
        # Log but don't retry - let tenacity handle retries
        logger.error(f"Upload attempt failed: {e}")
        raise  # Re-raise for tenacity to handle
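`upload_audio` and the status checks both rely on tenacity's `wait_exponential_jitter`. Ignoring the random jitter term (and possible off-by-one differences in tenacity's attempt indexing), the wait schedule is a capped geometric progression; a sketch for the upload settings above (initial=3, exp_base=2, max=30):

```python
from typing import List


def backoff_schedule(initial: float, exp_base: float, cap: float, attempts: int) -> List[float]:
    """Approximate wait before each retry, jitter omitted.

    Tenacity adds a random jitter term on top of this capped geometric growth.
    """
    return [min(initial * exp_base**n, cap) for n in range(attempts)]


print(backoff_schedule(initial=3, exp_base=2, cap=30, attempts=5))
```

The cap matters for long uploads: without it, the fourth retry would already wait 48 seconds rather than the configured 30-second maximum.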
PyannoteConfig

Bases: BaseSettings

Configuration constants for Pyannote API.

Source code in src/tnh_scholar/audio_processing/diarization/config.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class PyannoteConfig(BaseSettings):
    """Configuration constants for Pyannote API."""
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        env_prefix="PYANNOTE_",
        extra="ignore",
    )

    # API Endpoints
    base_url: str = "https://api.pyannote.ai/v1"

    @property
    def media_input_endpoint(self) -> str:
        return f"{self.base_url}/media/input"

    @property
    def diarize_endpoint(self) -> str:
        return f"{self.base_url}/diarize"

    @property
    def job_status_endpoint(self) -> str:
        return f"{self.base_url}/jobs"

    # Media
    media_prefix: str = "media://diarization-"
    media_content_type: str = "audio/mpeg"

    # Upload-specific settings
    upload_timeout: int = 300  # 5 minutes for large files
    upload_max_retries: int = 3

    # Network specific settings
    network_timeout: int = 3  # seconds

    # Polling
    polling_config: PollingConfig = PollingConfig()
base_url = 'https://api.pyannote.ai/v1' class-attribute instance-attribute
diarize_endpoint property
job_status_endpoint property
media_content_type = 'audio/mpeg' class-attribute instance-attribute
media_input_endpoint property
media_prefix = 'media://diarization-' class-attribute instance-attribute
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', case_sensitive=False, env_prefix='PYANNOTE_', extra='ignore') class-attribute instance-attribute
network_timeout = 3 class-attribute instance-attribute
polling_config = PollingConfig() class-attribute instance-attribute
upload_max_retries = 3 class-attribute instance-attribute
upload_timeout = 300 class-attribute instance-attribute
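PyannoteConfig reads overrides from PYANNOTE_-prefixed environment variables (and a .env file) via pydantic's BaseSettings. The prefix-mapping idea can be sketched without pydantic; note that real BaseSettings also handles .env loading, nested models, and much more careful type coercion:

```python
import os
from typing import Any, Dict


def load_overrides(defaults: Dict[str, Any], prefix: str = "PYANNOTE_") -> Dict[str, Any]:
    """Overlay prefixed environment variables onto field defaults.

    A stdlib sketch of what pydantic-settings does for PyannoteConfig.
    """
    config = dict(defaults)
    for field, default in defaults.items():
        raw = os.getenv(prefix + field.upper())
        if raw is not None:
            # Coerce to the default's type (pydantic does this far more carefully)
            config[field] = type(default)(raw)
    return config


defaults = {"base_url": "https://api.pyannote.ai/v1", "upload_timeout": 300}
os.environ["PYANNOTE_UPLOAD_TIMEOUT"] = "120"
print(load_overrides(defaults))
```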
diarize(audio_file_path, output_path=None, *, params=None, service=None, api_key=None, wait_until_complete=True)

One-shot convenience to generate a result and (optionally) write it.

This returns the DiarizationResponse. Writing is left to callers or diarize_to_file below.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def diarize(
    audio_file_path: Path,
    output_path: Optional[Path] = None,
    *,
    params: Optional[DiarizationParams] = None,
    service: Optional[DiarizationService] = None,
    api_key: Optional[str] = None,
    wait_until_complete: bool = True,
) -> DiarizationResponse:
    """One-shot convenience to generate a result and (optionally) write it.

    This returns the `DiarizationResponse`. Writing is left to callers or
    `diarize_to_file` below.
    """
    processor = DiarizationProcessor(
        audio_file_path,
        output_path=output_path,
        service=service,
        params=params,
        api_key=api_key,
    )
    return processor.generate(wait_until_complete=wait_until_complete)
diarize_to_file(audio_file_path, output_path=None, *, params=None, service=None, api_key=None, wait_until_complete=True)

Convenience helper: generate then export to JSON if successful; returns response

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def diarize_to_file(
    audio_file_path: Path,
    output_path: Optional[Path] = None,
    *,
    params: Optional[DiarizationParams] = None,
    service: Optional[DiarizationService] = None,
    api_key: Optional[str] = None,
    wait_until_complete: bool = True,
) -> DiarizationResponse:
    """Convenience helper: generate then export to JSON if successful; returns response"""
    processor = DiarizationProcessor(
        audio_file_path,
        output_path=output_path,
        service=service,
        params=params,
        api_key=api_key,
    )
    response = processor.generate(wait_until_complete=wait_until_complete)
    if isinstance(response, DiarizationSucceeded):
        processor.export()
    return response
audio
__all__ = ['AudioHandler', 'AudioHandlerConfig'] module-attribute
AudioHandler

Isolates audio operations and external dependencies (pydub, ffmpeg).

Source code in src/tnh_scholar/audio_processing/diarization/audio/handler.py
class AudioHandler:
    """Isolates audio operations and external dependencies (pydub, ffmpeg)."""

    def __init__(
        self, 
        config: AudioHandlerConfig = AudioHandlerConfig()
        ):
        self.config = config
        # Sensible fall‑backs for optional config values
        self.base_audio: AudioSegment
        self.output_format: Optional[str] = config.output_format
        self.input_format: Optional[str] = None

    def build_audio_chunk(self, chunk: DiarizationChunk, audio_file: Path) -> AudioChunk:
        """builds and sets the internal chunk.audio to be the new AudioChunk"""

        self._set_io_format(audio_file)
        base_audio = self._load_audio(audio_file)
        self._validate_segments(chunk)

        audio_segment = self._assemble_segments(chunk, base_audio)
        audio_chunk = AudioChunk(
            data=self._export_audio(audio_segment),
            start_ms=chunk.start_time,
            end_ms=chunk.end_time,
            format=self.output_format,
        )
        chunk.audio = audio_chunk
        return audio_chunk

    def export_audio_bytes(self, audio_segment: AudioSegment, format_str: Optional[str] = None) -> BytesIO:
        """Export AudioSegment to BytesIO for services/modules that require file-like objects."""
        return self._export_audio(audio_segment, format_str)

    def _set_io_format(self, audio_file: Path):
        formats = self.config.SUPPORTED_FORMATS
        suffix = audio_file.suffix.lstrip(".").lower()
        if not suffix or suffix not in formats:
            raise ValueError(
                f"Unsupported or missing audio file format: '{audio_file.suffix}'. "
                f"Supported formats are: {', '.join(sorted(formats))}"
            )
        self.input_format = suffix

        # Use input format if output format not specified
        self.output_format = self.output_format or self.input_format

    def _load_audio(self, audio_file: Path) -> AudioSegment:
        """Load the audio file and validate format."""
        return AudioSegment.from_file(audio_file, format=self.input_format)

    def _validate_segments(self, chunk: DiarizationChunk):
        """Ensure all segments have gap_before and spacing_time attributes set."""
        for i, segment in enumerate(chunk.segments):
            if not hasattr(segment, "gap_before") or not hasattr(segment, "spacing_time"):
                raise ValueError(
                    f"Segment at index {i} missing required gap annotations: "
                    f"gap_before={getattr(segment, 'gap_before', None)}, "
                    f"spacing_time={getattr(segment, 'spacing_time', None)}"
                )

    def _assemble_segments(self, chunk: DiarizationChunk, base_audio: AudioSegment) -> AudioSegment:
        """Assemble audio for the given diarization chunk using gap information."""
        assembled: AudioSegment = AudioSegment.empty()
        offset = 0
        prev_end: Optional[int] = None
        audio_length = len(base_audio)

        def _clamp(val, min_val, max_val):
            return max(min_val, min(val, max_val))

        def _add_silence(duration):
            nonlocal assembled, offset
            if duration > 0:
                assembled += AudioSegment.silent(duration=duration)
                offset += duration

        def _add_interval_audio(start, end):
            nonlocal assembled, offset
            start = _clamp(start, 0, audio_length)
            end = _clamp(end, 0, audio_length)
            if end > start:
                interval_audio = base_audio[start:end]
                assembled += interval_audio
                offset += len(interval_audio)

        def _add_segment_audio(start, end):
            nonlocal assembled, offset
            start = _clamp(start, 0, audio_length)
            end = _clamp(end, 0, audio_length)
            if end > start:
                seg_audio: AudioSegment = base_audio[start:end]
                assembled += seg_audio
                offset += len(seg_audio)
                return len(seg_audio)
            return 0

        for segment in chunk.segments:
            seg_start = int(segment.start)
            seg_end = int(segment.end)

            # Handle gap before segment
            if prev_end is not None:
                if self.config.silence_all_intervals or getattr(segment, "gap_before", False):
                    spacing_time = getattr(segment, "spacing_time", 0)
                    _add_silence(spacing_time)
                elif seg_start > prev_end:
                    _add_interval_audio(prev_end, seg_start)

            # Append current segment audio (clamped)
            segment.audio_map_start = offset
            _add_segment_audio(seg_start, seg_end)

            prev_end = seg_end

        return assembled

    # TODO: in _export_audio:
    # handle needed parameters for various export formats (can use kwargs for options)    
    def _export_audio(
        self, 
        audio_segment: AudioSegment,  
        format_str: Optional[str] = None
        ) -> BytesIO:
        """Export *audio segment* in the configured format and return raw bytes."""

        export_format = format_str or self.output_format
        supported_formats = self.config.SUPPORTED_FORMATS

        if not export_format:
            raise ConfigurationError("Cannot export. Output format not specified.")

        if export_format not in supported_formats:
            raise ValueError(
                f"Unsupported export format: '{export_format}'. "
                f"Supported formats are: {', '.join(sorted(supported_formats))}"
            )

        file_obj = BytesIO()
        try:
            audio_segment.export(file_obj, format=export_format)
            file_obj.seek(0)
        except Exception as e:
            logger.error(f"Failed to export audio segment: {e}")
            raise RuntimeError(f"Audio export failed: {e}") from e
        return file_obj
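The gap handling in `_assemble_segments` can be followed with plain arithmetic. The sketch below is a simplification that folds the chunker's threshold decision and the handler's assembly into one function: long gaps are replaced by a fixed spacing of silence, short gaps keep the original interval, and each segment's start offset in the assembled audio is recorded (mirroring `segment.audio_map_start`). All values are milliseconds; the threshold and spacing values match the `ChunkConfig` defaults shown later in this reference.

```python
# Arithmetic model of segment assembly: gaps longer than gap_threshold
# are compressed to spacing_time of silence; shorter gaps are kept.
def assembled_length(segments, gap_threshold=4000, spacing_time=1000):
    total = 0
    prev_end = None
    offsets = []  # start offset of each segment in the assembled audio
    for start, end in segments:
        if prev_end is not None:
            gap = start - prev_end
            total += spacing_time if gap > gap_threshold else max(gap, 0)
        offsets.append(total)          # mirrors segment.audio_map_start
        total += end - start
        prev_end = end
    return total, offsets

# Two 2 s segments separated by a 6 s pause: the 6000 ms gap is
# compressed to 1000 ms, so 10 s of source becomes 5 s of output.
length, offsets = assembled_length([(0, 2000), (8000, 10000)])
print(length, offsets)  # 5000 [0, 3000]
```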
base_audio instance-attribute
config = config instance-attribute
input_format = None instance-attribute
output_format = config.output_format instance-attribute
__init__(config=AudioHandlerConfig())
Source code in src/tnh_scholar/audio_processing/diarization/audio/handler.py
def __init__(
    self, 
    config: AudioHandlerConfig = AudioHandlerConfig()
    ):
    self.config = config
    # Sensible fall‑backs for optional config values
    self.base_audio: AudioSegment
    self.output_format: Optional[str] = config.output_format
    self.input_format: Optional[str] = None
build_audio_chunk(chunk, audio_file)

builds and sets the internal chunk.audio to be the new AudioChunk

Source code in src/tnh_scholar/audio_processing/diarization/audio/handler.py
def build_audio_chunk(self, chunk: DiarizationChunk, audio_file: Path) -> AudioChunk:
    """builds and sets the internal chunk.audio to be the new AudioChunk"""

    self._set_io_format(audio_file)
    base_audio = self._load_audio(audio_file)
    self._validate_segments(chunk)

    audio_segment = self._assemble_segments(chunk, base_audio)
    audio_chunk = AudioChunk(
        data=self._export_audio(audio_segment),
        start_ms=chunk.start_time,
        end_ms=chunk.end_time,
        format=self.output_format,
    )
    chunk.audio = audio_chunk
    return audio_chunk
export_audio_bytes(audio_segment, format_str=None)

Export AudioSegment to BytesIO for services/modules that require file-like objects.

Source code in src/tnh_scholar/audio_processing/diarization/audio/handler.py
def export_audio_bytes(self, audio_segment: AudioSegment, format_str: Optional[str] = None) -> BytesIO:
    """Export AudioSegment to BytesIO for services/modules that require file-like objects."""
    return self._export_audio(audio_segment, format_str)
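Exporting to `BytesIO` matters because downstream consumers (transcription clients, upload helpers) typically want a file-like object rather than a path, and the buffer must be rewound before it is read. A minimal stdlib sketch of the same idea, using `wave` in place of pydub's `export`:

```python
# Write WAV data into an in-memory buffer and rewind it, as
# _export_audio does with file_obj.seek(0), so the next reader starts
# at byte 0 instead of at end-of-stream.
import io
import wave

buf = io.BytesIO()
with wave.open(buf, "wb") as w:
    w.setnchannels(1)
    w.setsampwidth(2)
    w.setframerate(16_000)
    w.writeframes(b"\x00\x00" * 16_000)  # 1 s of 16-bit mono silence
buf.seek(0)  # rewind before handing the buffer to a consumer
print(buf.read(4))  # b'RIFF'
```

Without the `seek(0)`, a consumer calling `buf.read()` would get empty bytes, which is why `_export_audio` rewinds before returning.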
AudioHandlerConfig

Bases: BaseSettings

Configuration settings for the AudioHandler. All audio time units are milliseconds (int)

Source code in src/tnh_scholar/audio_processing/diarization/audio/config.py
class AudioHandlerConfig(BaseSettings):
    """
    Configuration settings for the AudioHandler.
    All audio time units are milliseconds (int)
    """

    output_format: Optional[str] = Field(
        default=None,
        description=
        "Audio output format used when exporting segments (e.g., 'wav', 'mp3')."
    )
    temp_storage_dir: Optional[Path] = Field(
        default=None,
        description=
        "Optional directory path for storing temporary audio files (currently unused)."
    )
    max_segment_length: Optional[int] = Field(
        default=None,
        description="Maximum allowed segment length (in milliseconds)."
    )
    silence_all_intervals: bool = Field(
        default=False,
        description="If True, replace every non-zero interval between consecutive diarization segments " 
        "with silence of length spacing_time."
    )
    SUPPORTED_FORMATS: frozenset = frozenset({"mp3", "wav", "flac", "ogg", "m4a", "mp4"})
    class Config:
        env_prefix = "AUDIO_HANDLER_"  # Optional: allow env vars like AUDIO_HANDLER_OUTPUT_FORMAT
SUPPORTED_FORMATS = frozenset({'mp3', 'wav', 'flac', 'ogg', 'm4a', 'mp4'}) class-attribute instance-attribute
max_segment_length = Field(default=None, description='Maximum allowed segment length (in milliseconds).') class-attribute instance-attribute
output_format = Field(default=None, description="Audio output format used when exporting segments (e.g., 'wav', 'mp3').") class-attribute instance-attribute
silence_all_intervals = Field(default=False, description='If True, replace every non-zero interval between consecutive diarization segments with silence of length spacing_time.') class-attribute instance-attribute
temp_storage_dir = Field(default=None, description='Optional directory path for storing temporary audio files (currently unused).') class-attribute instance-attribute
Config
Source code in src/tnh_scholar/audio_processing/diarization/audio/config.py
class Config:
    env_prefix = "AUDIO_HANDLER_"  # Optional: allow env vars like AUDIO_HANDLER_OUTPUT_FORMAT
env_prefix = 'AUDIO_HANDLER_' class-attribute instance-attribute
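With the `AUDIO_HANDLER_` prefix, pydantic resolves each field from an explicit argument first, then from the prefixed environment variable, then from the field default. A plain-Python analogue of that precedence for the `output_format` field (the `load_output_format` helper is illustrative, not part of the package):

```python
# Analogue of BaseSettings precedence with env_prefix = "AUDIO_HANDLER_":
# explicit value > AUDIO_HANDLER_OUTPUT_FORMAT env var > default.
import os

def load_output_format(prefix="AUDIO_HANDLER_", override=None, default=None):
    if override is not None:
        return override
    return os.environ.get(prefix + "OUTPUT_FORMAT", default)

os.environ["AUDIO_HANDLER_OUTPUT_FORMAT"] = "wav"
print(load_output_format())                # wav  (from the environment)
print(load_output_format(override="mp3"))  # mp3  (explicit value wins)
```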
config
handler

Audio handler utilities for slicing and assembling audio around diarization chunks. Designed for pipeline-friendly, single-responsibility methods so that higher-level services can remain agnostic of the underlying audio library.

This implementation purposely keeps logic minimal for testing.

logger = get_child_logger(__name__) module-attribute
chunker
logger = get_child_logger(__name__) module-attribute
DiarizationChunker

Class for chunking diarization results into processing units based on configurable duration targets.

Source code in src/tnh_scholar/audio_processing/diarization/chunker.py
class DiarizationChunker:
    """
    Class for chunking diarization results into processing units
    based on configurable duration targets.
    """

    def __init__(self, **config_options):
        """Initialize chunker with additional config_options."""
        self.config = ChunkConfig()

        self._handle_config_options(config_options)


    def extract_contiguous_chunks(self, segments: List[DiarizedSegment]) -> List[DiarizationChunk]:
        """
        Split diarization segments into contiguous chunks of
        approximately target_duration, without splitting on speaker changes.

        Args:
            segments: List of speaker segments from diarization

        Returns:
            List[Chunk]: Flat list of contiguous chunks
        """
        if not segments:
            return []

        extractor = self._ChunkExtractor(self.config, split_on_speaker_change=False)
        return extractor.extract(segments)

    class _ChunkExtractor:
        def __init__(self, config: ChunkConfig, split_on_speaker_change: bool = True):
            self.config = config
            self.split_on_speaker_change = split_on_speaker_change
            self.gap_threshold = self.config.gap_threshold
            self.spacing = self.config.gap_spacing_time
            self.chunks: List[DiarizationChunk] = []
            self.current_chunk_segments: List[DiarizedSegment] = []
            self.chunk_start: int = 0
            self.current_speaker = ""
            self.accumulated_time: int = 0

        @property
        def last_segment(self):
            return self.current_chunk_segments[-1] if self.current_chunk_segments else None

        def extract(self, segments: List[DiarizedSegment]) -> List[DiarizationChunk]:
            if not segments:
                return []

            self.chunk_start = int(segments[0].start)
            self.current_speaker = segments[0].speaker
            for segment in segments:
                self._check_segment_duration(segment)  
                self._process_segment(segment)

            self._finalize_last_chunk()
            return self.chunks

        def _process_segment(self, segment: DiarizedSegment):
            if self._should_split(segment):
                self._finalize_current_chunk(segment)
                self.chunk_start = int(segment.start)                
            self._add_segment(segment)

        def _add_segment(self, segment: DiarizedSegment):
            gap_time =  self._gap_time(segment)
            if gap_time > self.gap_threshold:
                segment.gap_before = True
                segment.spacing_time = self.spacing
                self.accumulated_time += int(segment.duration) + self.spacing
            else:
                segment.gap_before = False
                segment.spacing_time = max(gap_time, 0)
                self.accumulated_time += int(segment.duration) + gap_time
            self.current_chunk_segments.append(segment)
            self.current_speaker = segment.speaker

        def _gap_time(self, segment) -> int:
            if self.last_segment is None:
                # If no last_segment, this is first segment, so no gap.
                return 0 
            else:
                return segment.start - self.last_segment.end


        def _should_split(self, segment: DiarizedSegment) -> bool:
            gap_time = self._gap_time(segment)
            interval_time = gap_time if gap_time < self.gap_threshold else self.spacing
            accumulated_time = self.accumulated_time + interval_time + segment.duration
            return accumulated_time >= self.config.target_duration 

        def _finalize_current_chunk(self, next_segment: Optional[DiarizedSegment]):
            if self.current_chunk_segments:
                assert self.last_segment is not None
                self.chunks.append(
                    DiarizationChunk(
                        start_time=int(self.chunk_start),
                        end_time=int(self.last_segment.end), 
                        segments=self.current_chunk_segments.copy(),
                        audio=None,
                        accumulated_time=self.accumulated_time
                    )
                )
                self._reset_chunk_state(next_segment)             

        def _reset_chunk_state(self, next_segment):
            self.current_chunk_segments = []
            self.accumulated_time = 0
            if self.split_on_speaker_change and next_segment:
                    self.current_speaker = next_segment.speaker

        def _finalize_last_chunk(self):
            if self.current_chunk_segments:
                self._handle_final_segments()

        def _check_segment_duration(self, segment: DiarizedSegment) -> None:
            """Check if segment exceeds target duration and issue warning if needed."""
            if segment.duration > self.config.target_duration:
                logger.warning(f"Found segment longer than "
                            f"target duration: {segment.duration_sec:.0f}s")

        def _handle_final_segments(self) -> None:
            """Append final segments to last chunk if below min duration."""
            approx_remaining_time = sum(segment.duration for segment in self.current_chunk_segments)
            final_time = self.accumulated_time + approx_remaining_time
            min_time = self.config.min_duration

            if final_time < min_time and self.chunks:
               self._merge_to_last_chunk()
            else:
                # Create standalone chunk
                self._finalize_current_chunk(next_segment=None)

        def _merge_to_last_chunk(self):
            """Merge segments to the last chunk processed. self.chunks cannot be empty."""
            assert self.chunks
            self.chunks[-1].segments.extend(self.current_chunk_segments)
            self.chunks[-1].end_time = int(self.current_chunk_segments[-1].end)
            self.chunks[-1].accumulated_time += self.accumulated_time



    def _handle_config_options(self, config_options: Dict[str, Any]) -> None:
        """
        Handles additional configuration options, 
        logging a warning for unrecognized keys.
        """
        for key, value in config_options.items():
            if hasattr(self.config, key):
                setattr(self.config, key, value)
            else:
                logger.warning(f"Unrecognized configuration option: {key}")
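The core accumulation rule in `_ChunkExtractor` is: keep appending segments until the running duration would reach `target_duration`, then close the chunk and start a new one. A simplified, self-contained sketch of that rule on `(start, end)` tuples in milliseconds (gap spacing and the final-chunk merge from `_handle_final_segments` are omitted):

```python
# Simplified chunking rule: split before adding a segment whenever the
# accumulated duration would reach target_duration (5 min default).
def chunk_segments(segments, target_duration=300_000):
    chunks, current, acc = [], [], 0
    for start, end in segments:
        duration = end - start
        if current and acc + duration >= target_duration:
            chunks.append(current)
            current, acc = [], 0
        current.append((start, end))
        acc += duration
    if current:
        chunks.append(current)
    return chunks

# Four 2-minute segments against the 5-minute target: each chunk takes
# two segments (4 min) because a third would push it to 6 min.
segs = [(i * 120_000, (i + 1) * 120_000) for i in range(4)]
print([len(c) for c in chunk_segments(segs)])  # [2, 2]
```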
config = ChunkConfig() instance-attribute
__init__(**config_options)

Initialize chunker with additional config_options.

Source code in src/tnh_scholar/audio_processing/diarization/chunker.py
def __init__(self, **config_options):
    """Initialize chunker with additional config_options."""
    self.config = ChunkConfig()

    self._handle_config_options(config_options)
extract_contiguous_chunks(segments)

Split diarization segments into contiguous chunks of approximately target_duration, without splitting on speaker changes.

Parameters:

Name Type Description Default
segments List[DiarizedSegment]

List of speaker segments from diarization

required

Returns:

Type Description
List[DiarizationChunk]

List[Chunk]: Flat list of contiguous chunks

Source code in src/tnh_scholar/audio_processing/diarization/chunker.py
def extract_contiguous_chunks(self, segments: List[DiarizedSegment]) -> List[DiarizationChunk]:
    """
    Split diarization segments into contiguous chunks of
    approximately target_duration, without splitting on speaker changes.

    Args:
        segments: List of speaker segments from diarization

    Returns:
        List[Chunk]: Flat list of contiguous chunks
    """
    if not segments:
        return []

    extractor = self._ChunkExtractor(self.config, split_on_speaker_change=False)
    return extractor.extract(segments)
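The chunking strategy implied by `extract_contiguous_chunks` and `_handle_final_segments` can be sketched without the package types. This is a minimal reduction under assumed semantics: segments are `(start_ms, end_ms)` pairs, chunks accumulate until `target_duration`, and a tail shorter than `min_duration` is merged into the previous chunk:

```python
TARGET_MS = 300_000  # default target_duration (5 minutes)
MIN_MS = 30_000      # default min_duration (30 seconds)

def chunk_segments(segments, target=TARGET_MS, minimum=MIN_MS):
    """Greedily accumulate segments into chunks of ~target duration;
    merge a too-short final chunk into the previous one."""
    chunks, current, acc = [], [], 0
    for start, end in segments:
        current.append((start, end))
        acc += end - start
        if acc >= target:
            chunks.append(current)
            current, acc = [], 0
    if current:
        if acc < minimum and chunks:
            chunks[-1].extend(current)  # merge short tail into last chunk
        else:
            chunks.append(current)      # standalone final chunk
    return chunks
```

Because chunks never split mid-segment, speaker boundaries inside a chunk are preserved for downstream transcription.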
config
ChunkConfig

Bases: BaseSettings

Configuration for chunking

Source code in src/tnh_scholar/audio_processing/diarization/config.py
class ChunkConfig(BaseSettings):
    """Configuration for chunking"""
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive = False,
        env_prefix = "CHUNK_",
        extra="ignore",
    )

    # Target duration for each chunk in milliseconds (default: 5 minutes = 300,000ms)
    target_duration: int = 300_000

    # Minimum duration for final chunk (in ms); shorter chunks are merged
    min_duration: int = 30_000 # 30 seconds

    # Gap threshold (ms): a gap between segments larger than this is flagged
    gap_threshold: int = 4000

    # Spacing (ms) inserted between segments that are more than gap_threshold apart
    gap_spacing_time: int = 1000
gap_spacing_time = 1000 class-attribute instance-attribute
gap_threshold = 4000 class-attribute instance-attribute
min_duration = 30000 class-attribute instance-attribute
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', case_sensitive=False, env_prefix='CHUNK_', extra='ignore') class-attribute instance-attribute
target_duration = 300000 class-attribute instance-attribute
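Because `ChunkConfig` uses `env_prefix="CHUNK_"`, each field can be overridden by a correspondingly named environment variable (e.g. `CHUNK_TARGET_DURATION`). A minimal stand-in for that behavior, without pydantic-settings, looks like this (the `load_chunk_config` helper is illustrative, not part of the package):

```python
import os

# Defaults taken from the ChunkConfig fields shown above.
DEFAULTS = {"target_duration": 300_000, "min_duration": 30_000,
            "gap_threshold": 4_000, "gap_spacing_time": 1_000}

def load_chunk_config(env=None, prefix="CHUNK_"):
    """Mimic the CHUNK_-prefixed environment override that BaseSettings provides."""
    env = os.environ if env is None else env
    cfg = dict(DEFAULTS)
    for field in cfg:
        raw = env.get(prefix + field.upper())
        if raw is not None:
            cfg[field] = int(raw)
    return cfg

cfg = load_chunk_config({"CHUNK_TARGET_DURATION": "120000"})
```

In the real class, pydantic-settings also reads a `.env` file and matches field names case-insensitively.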
DiarizationConfig

Bases: BaseSettings

Source code in src/tnh_scholar/audio_processing/diarization/config.py
class DiarizationConfig(BaseSettings):
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive = False,
        env_prefix = "DIARIZATION_",
        extra="ignore",
    )
    speaker: SpeakerConfig = SpeakerConfig()
    chunk: ChunkConfig = ChunkConfig()
    language: LanguageConfig = LanguageConfig()
    mapping: MappingPolicy = MappingPolicy()
chunk = ChunkConfig() class-attribute instance-attribute
language = LanguageConfig() class-attribute instance-attribute
mapping = MappingPolicy() class-attribute instance-attribute
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', case_sensitive=False, env_prefix='DIARIZATION_', extra='ignore') class-attribute instance-attribute
speaker = SpeakerConfig() class-attribute instance-attribute
LanguageConfig

Bases: BaseSettings

Source code in src/tnh_scholar/audio_processing/diarization/config.py
class LanguageConfig(BaseSettings):
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive = False,
        env_prefix = "LANGUAGE_",
        extra="ignore",
    )
    # Duration for language probe sampling in milliseconds (default: 10 seconds)
    probe_time: int = 10_000

    # The file format used for language probe file-like objects
    export_format: str = "wav"

    # Default language
    default_language: str = "en"
default_language = 'en' class-attribute instance-attribute
export_format = 'wav' class-attribute instance-attribute
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', case_sensitive=False, env_prefix='LANGUAGE_', extra='ignore') class-attribute instance-attribute
probe_time = 10000 class-attribute instance-attribute
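The apparent use of `probe_time` (assumed from the segment-audio notes elsewhere on this page, which say probes run from 0 to the segment duration) is to take an initial window of a segment's audio as a language-detection sample, clamped to the segment length:

```python
def probe_window(duration_ms, probe_time=10_000):
    """Return the (start, end) window in ms, relative to the segment audio,
    to sample for language probing. Sketch under assumed semantics."""
    return (0, min(probe_time, duration_ms))
```

Segments shorter than `probe_time` are probed in full.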
MappingPolicy

Bases: BaseSettings

Mapping policy for transport→domain shaping.

TODO (future parameters to consider):
  • min_segment_ms (int): drop micro-segments below threshold
  • merge_gap_ms (int): merge adjacent same-speaker segments if gap ≤ this
  • round_ms_to (int): quantize boundaries (e.g., 10ms)
  • confidence_floor (float | None): filter out low-confidence segments
  • suppress_unlabeled (bool): drop segments missing speaker id
  • attach_raw_payload (bool): persist raw API payload in metadata
  • version (int): policy versioning for reproducibility

Source code in src/tnh_scholar/audio_processing/diarization/config.py
class MappingPolicy(BaseSettings):
    """Mapping policy for transport→domain shaping.

    TODO (future parameters to consider):
    - min_segment_ms: int                # drop micro-segments below threshold
    - merge_gap_ms: int                  # merge adjacent same-speaker if gap ≤ this
    - round_ms_to: int                   # quantize boundaries (e.g., 10ms)
    - confidence_floor: float | None     # filter out low-confidence segments
    - suppress_unlabeled: bool           # drop segments missing speaker id
    - attach_raw_payload: bool           # persist raw API payload in metadata
    - version: int                       # policy versioning for reproducibility
    """
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive = False,
        env_prefix = "MAPPING_",
        extra="ignore",
    )

    # Current, minimal policy (kept in sync with existing flags in use)
    default_speaker_label: str = "SPEAKER_00"
    single_speaker: bool = False
default_speaker_label = 'SPEAKER_00' class-attribute instance-attribute
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', case_sensitive=False, env_prefix='MAPPING_', extra='ignore') class-attribute instance-attribute
single_speaker = False class-attribute instance-attribute
PollingConfig

Bases: BaseSettings

Configuration constants for a generic polling class used for Pyannote API polling.

Source code in src/tnh_scholar/audio_processing/diarization/config.py
class PollingConfig(BaseSettings):
    """Configuration constants for a generic polling class used for Pyannote API polling."""
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive = False,
        env_prefix = "PYANNOTE_POLL_",
        extra="ignore",
    )

    polling_interval: int = 15
    polling_timeout: float | None = 300.0  # seconds. set to None for time unlimited
    initial_poll_time: int = 7
    exp_base: int = 2
    max_interval: int = 30
exp_base = 2 class-attribute instance-attribute
initial_poll_time = 7 class-attribute instance-attribute
max_interval = 30 class-attribute instance-attribute
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', case_sensitive=False, env_prefix='PYANNOTE_POLL_', extra='ignore') class-attribute instance-attribute
polling_interval = 15 class-attribute instance-attribute
polling_timeout = 300.0 class-attribute instance-attribute
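The polling class itself is not shown on this page, so the exact schedule is an assumption; one common pattern consistent with `initial_poll_time`, `exp_base`, and `max_interval` is exponential backoff capped at the maximum interval (how `polling_interval` interacts with this schedule is not visible here):

```python
def poll_intervals(n, initial=7, base=2, cap=30):
    """Sketch of a capped exponential backoff schedule, in seconds,
    using the PollingConfig defaults shown above."""
    return [min(initial * base**i, cap) for i in range(n)]
```

With the defaults, successive waits grow 7, 14, 28 seconds and then stay pinned at 30, while `polling_timeout` bounds the total wait (None means unlimited).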
PyannoteConfig

Bases: BaseSettings

Configuration constants for Pyannote API.

Source code in src/tnh_scholar/audio_processing/diarization/config.py
class PyannoteConfig(BaseSettings):
    """Configuration constants for Pyannote API."""
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive = False,
        env_prefix = "PYANNOTE_",
        extra="ignore",
    )

    # API Endpoints
    base_url: str = "https://api.pyannote.ai/v1"

    @property
    def media_input_endpoint(self) -> str:
        return f"{self.base_url}/media/input"

    @property
    def diarize_endpoint(self) -> str:
        return f"{self.base_url}/diarize"

    @property
    def job_status_endpoint(self) -> str:
        return f"{self.base_url}/jobs"

    # Media
    media_prefix: str = "media://diarization-"
    media_content_type: str = "audio/mpeg"

    # Upload-specific settings
    upload_timeout: int = 300  # 5 minutes for large files
    upload_max_retries: int = 3

    # Network specific settings
    network_timeout: int = 3 # seconds

    # Polling
    polling_config: PollingConfig = PollingConfig()
base_url = 'https://api.pyannote.ai/v1' class-attribute instance-attribute
diarize_endpoint property
job_status_endpoint property
media_content_type = 'audio/mpeg' class-attribute instance-attribute
media_input_endpoint property
media_prefix = 'media://diarization-' class-attribute instance-attribute
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', case_sensitive=False, env_prefix='PYANNOTE_', extra='ignore') class-attribute instance-attribute
network_timeout = 3 class-attribute instance-attribute
polling_config = PollingConfig() class-attribute instance-attribute
upload_max_retries = 3 class-attribute instance-attribute
upload_timeout = 300 class-attribute instance-attribute
SpeakerConfig

Bases: BaseSettings

Configuration settings for speaker block generation.

Source code in src/tnh_scholar/audio_processing/diarization/config.py
class SpeakerConfig(BaseSettings):
    """Configuration settings for speaker block generation."""
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive = False,
        env_prefix = "SPEAKER_",
        extra="ignore",
    )

    # Set the default gap allowed between segments that will allow grouping of
    # consecutive same-speaker segments
    same_speaker_gap_threshold: TimeMs = TimeMs.from_seconds(2)

    default_speaker_label: str = "SPEAKER_00"

    # If set to true, all speakers are set to default speaker label
    single_speaker: bool = False
default_speaker_label = 'SPEAKER_00' class-attribute instance-attribute
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', case_sensitive=False, env_prefix='SPEAKER_', extra='ignore') class-attribute instance-attribute
same_speaker_gap_threshold = TimeMs.from_seconds(2) class-attribute instance-attribute
single_speaker = False class-attribute instance-attribute
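The `same_speaker_gap_threshold` setting governs when consecutive segments from one speaker are grouped into a block (a `SpeakerBlock`, documented below). A minimal sketch of that grouping rule, assuming segments arrive as sorted `(speaker, start_ms, end_ms)` tuples:

```python
GAP_MS = 2_000  # default same_speaker_gap_threshold (2 seconds)

def group_same_speaker(segments, gap=GAP_MS):
    """Group consecutive same-speaker segments into blocks when the gap
    between them is at most `gap` ms. Sketch under assumed semantics."""
    blocks = []
    for spk, start, end in segments:
        if (blocks
                and blocks[-1]["speaker"] == spk
                and start - blocks[-1]["segments"][-1][2] <= gap):
            blocks[-1]["segments"].append((spk, start, end))
        else:
            blocks.append({"speaker": spk, "segments": [(spk, start, end)]})
    return blocks
```

A new block starts on any speaker change, or on a same-speaker gap above the threshold.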
models
logger = get_child_logger(__name__) module-attribute
AudioChunk

Bases: BaseModel

Source code in src/tnh_scholar/audio_processing/diarization/models.py
class AudioChunk(BaseModel):
    data: BytesIO
    start_ms: int
    end_ms: int
    sample_rate: Optional[int] = None
    channels: Optional[int] = None
    format: Optional[str] = None

    class Config:
        arbitrary_types_allowed = True
channels = None class-attribute instance-attribute
data instance-attribute
end_ms instance-attribute
format = None class-attribute instance-attribute
sample_rate = None class-attribute instance-attribute
start_ms instance-attribute
Config
Source code in src/tnh_scholar/audio_processing/diarization/models.py
class Config:
    arbitrary_types_allowed = True
arbitrary_types_allowed = True class-attribute instance-attribute
AugDiarizedSegment

Bases: DiarizedSegment

DiarizedSegment with additional chunking/processing metadata.

This class extends DiarizedSegment and adds fields that are only set during chunk accumulation or downstream processing.

Attributes:

Name Type Description
gap_before bool

Indicates if there is a gap greater than the configured threshold before this segment. Set only during chunk accumulation.

spacing_time TimeMs

The spacing (in ms) between this and the previous segment, possibly adjusted if there is a gap before. Set only during chunk accumulation.

audio TNHAudioSegment

The audio data for this segment, sliced from the original audio.

Notes
  • The audio field is a slice of the original audio corresponding to this segment.
  • All time values (start, end, duration) are relative to the original audio.
  • When slicing or probing the audio field, use times relative to 0 (i.e., 0 to duration).
  • For language probing or any operation on audio, always use 0 as the start and duration as the end.
Source code in src/tnh_scholar/audio_processing/diarization/models.py
class AugDiarizedSegment(DiarizedSegment):
    """
    DiarizedSegment with additional chunking/processing metadata.

    This class extends `DiarizedSegment` and adds fields that are only set during
    chunk accumulation or downstream processing.

    Attributes:
        gap_before (bool): Indicates if there is a gap greater than the configured threshold
            before this segment. Set only during chunk accumulation.
        spacing_time (TimeMs): The spacing (in ms) between this and the previous segment,
            possibly adjusted if there is a gap before. Set only during chunk accumulation.
        audio (AudioSegment): The audio data for this segment, sliced from the original audio.

    Notes:
        - The `audio` field is a slice of the original audio corresponding to this segment.
        - All time values (start, end, duration) are relative to the original audio.
        - When slicing or probing the `audio` field, use times relative to 0 (i.e., 0 to duration).
        - For language probing or any operation on `audio`, 
          always use 0 as the start and `duration` as the end.
    """

    @property
    def relative_start(self) -> TimeMs:
        """Start time relative to the segment audio (always 0)."""
        return TimeMs(0)

    @property
    def relative_end(self) -> TimeMs:
        """End time relative to the segment audio (duration of segment)."""
        return self.duration

    gap_before_new: bool  # rename when ready to move over to using this class 
    spacing_time_new: TimeMs  # rename when ready to move over to using this class 
    audio: Optional[AudioSegment]

    @classmethod
    def from_segment(
        cls,
        segment: DiarizedSegment,
        gap_before: Optional[bool] = None,
        spacing_time_new: Optional[TimeMs] = None,
        audio: Optional[AudioSegment] = None,
        **kwargs
    ) -> "AugDiarizedSegment":
        """
        Create an AugDiarizedSegment from a DiarizedSegment, with optional new fields.
        Args:
            segment (DiarizedSegment): The base segment to copy fields from.
        gap_before (bool, optional): Value for gap_before_new. Defaults to False.
            spacing_time_new (TimeMs, optional): Value for spacing_time_new. Defaults to None.
            audio (AudioSegment, optional): Audio data for this segment. Defaults to None.
            **kwargs: Any additional fields to override.
        Returns:
            AugDiarizedSegment: The new augmented segment.
        """
        return cls(
            speaker=segment.speaker,
            start=segment.start,
            end=segment.end,
            audio_map_start=segment.audio_map_start,
            gap_before=segment.gap_before,
            spacing_time=segment.spacing_time,
            gap_before_new=gap_before if gap_before is not None else (segment.gap_before or False),
            spacing_time_new=spacing_time_new if spacing_time_new is not None else TimeMs(0),
            audio=audio,
            **kwargs
        )

    class Config:
        arbitrary_types_allowed = True
audio instance-attribute
gap_before_new instance-attribute
relative_end property

End time relative to the segment audio (duration of segment).

relative_start property

Start time relative to the segment audio (always 0).

spacing_time_new instance-attribute
Config
Source code in src/tnh_scholar/audio_processing/diarization/models.py
class Config:
    arbitrary_types_allowed = True
arbitrary_types_allowed = True class-attribute instance-attribute
from_segment(segment, gap_before=None, spacing_time_new=None, audio=None, **kwargs) classmethod

Create an AugDiarizedSegment from a DiarizedSegment, with optional new fields.

Parameters:

Name Type Description Default
segment DiarizedSegment

The base segment to copy fields from.

required
gap_before bool

Value for gap_before_new. Defaults to False.

None
spacing_time_new TimeMs

Value for spacing_time_new. Defaults to None.

None
audio AudioSegment

Audio data for this segment. Defaults to None.

None
**kwargs

Any additional fields to override.

{}

Returns:

Type Description
AugDiarizedSegment

The new augmented segment.

Source code in src/tnh_scholar/audio_processing/diarization/models.py
@classmethod
def from_segment(
    cls,
    segment: DiarizedSegment,
    gap_before: Optional[bool] = None,
    spacing_time_new: Optional[TimeMs] = None,
    audio: Optional[AudioSegment] = None,
    **kwargs
) -> "AugDiarizedSegment":
    """
    Create an AugDiarizedSegment from a DiarizedSegment, with optional new fields.
    Args:
        segment (DiarizedSegment): The base segment to copy fields from.
        gap_before (bool, optional): Value for gap_before_new. Defaults to False.
        spacing_time_new (TimeMs, optional): Value for spacing_time_new. Defaults to None.
        audio (AudioSegment, optional): Audio data for this segment. Defaults to None.
        **kwargs: Any additional fields to override.
    Returns:
        AugDiarizedSegment: The new augmented segment.
    """
    return cls(
        speaker=segment.speaker,
        start=segment.start,
        end=segment.end,
        audio_map_start=segment.audio_map_start,
        gap_before=segment.gap_before,
        spacing_time=segment.spacing_time,
        gap_before_new=gap_before if gap_before is not None else (segment.gap_before or False),
        spacing_time_new=spacing_time_new if spacing_time_new is not None else TimeMs(0),
        audio=audio,
        **kwargs
    )
DiarizationChunk

Bases: BaseModel

Represents a chunk of segments to be processed together.

Source code in src/tnh_scholar/audio_processing/diarization/models.py
class DiarizationChunk(BaseModel):
    """Represents a chunk of segments to be processed together."""
    start_time: int  # Start time in milliseconds
    end_time: int    # End time in milliseconds
    audio: Optional[AudioChunk] = None
    segments: List[DiarizedSegment]
    accumulated_time: int = 0
    class Config:
        arbitrary_types_allowed = True

    @property
    def total_duration(self) -> int:
        """Get chunk duration in milliseconds."""
        return self.end_time - self.start_time

    @property
    def total_duration_sec(self) -> float:
        return convert_ms_to_sec(self.total_duration)

    @property
    def total_duration_time(self) -> "TimeMs":
        return TimeMs(self.total_duration)
accumulated_time = 0 class-attribute instance-attribute
audio = None class-attribute instance-attribute
end_time instance-attribute
segments instance-attribute
start_time instance-attribute
total_duration property

Get chunk duration in milliseconds.

total_duration_sec property
total_duration_time property
Config
Source code in src/tnh_scholar/audio_processing/diarization/models.py
class Config:
    arbitrary_types_allowed = True
arbitrary_types_allowed = True class-attribute instance-attribute
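The duration arithmetic on `DiarizationChunk` reduces to subtraction and a ms-to-seconds conversion. The `convert_ms_to_sec` helper below is a guess at the behavior of the identically named function used in the source; it is not imported from the package:

```python
def total_duration_ms(start_time, end_time):
    """Chunk duration in milliseconds (mirrors the total_duration property)."""
    return end_time - start_time

def convert_ms_to_sec(ms):
    """Assumed behavior of the convert_ms_to_sec helper used above."""
    return ms / 1000.0
```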
DiarizedSegment

Bases: BaseModel

Represents a diarized audio segment for a single speaker.

Attributes:

Name Type Description
speaker str

The speaker label for this segment.

start TimeMs

Start time in milliseconds.

end TimeMs

End time in milliseconds.

audio_map_start Optional[int]

Location in the audio output file, if mapped.

gap_before Optional[bool]

Indicates if there is a gap greater than the configured threshold before this segment. This attribute is set exclusively by ChunkAccumulator.add_segment() and should be None until that point.

spacing_time Optional[int]

The spacing (in ms) between this and the previous segment, possibly adjusted if there is a gap before. This attribute is also set exclusively by ChunkAccumulator.add_segment() and should be None until that point.

Notes
  • gap_before and spacing_time are not set during initial diarization, but are assigned only when the segment is accumulated into a chunk for downstream audio handling.
  • These fields should be considered write-once and must not be mutated elsewhere.
Source code in src/tnh_scholar/audio_processing/diarization/models.py
class DiarizedSegment(BaseModel):
    """
    Represents a diarized audio segment for a single speaker.

    Attributes:
        speaker (str): The speaker label for this segment.
        start (TimeMs): Start time in milliseconds.
        end (TimeMs): End time in milliseconds.
        audio_map_start (Optional[int]): Location in the audio output file, if mapped.
        gap_before (Optional[bool]): Indicates if there is a gap greater than the configured threshold
            before this segment. This attribute is set exclusively by `ChunkAccumulator.add_segment()`
            and should be None until that point.
        spacing_time (Optional[int]): The spacing (in ms) between this and the previous segment,
            possibly adjusted if there is a gap before. This attribute is also set exclusively by
            `ChunkAccumulator.add_segment()` and should be None until that point.

    Notes:
        - `gap_before` and `spacing_time` are not set during initial diarization, but are assigned
          only when the segment is accumulated into a chunk for downstream audio handling.
        - These fields should be considered write-once and must not be mutated elsewhere.
    """
    speaker: str
    start: TimeMs  # Start time in milliseconds
    end: TimeMs    # End time in milliseconds
    audio_map_start: Optional[int] # location in the audio output file
    gap_before: Optional[bool] # indicates a gap > gap_threshold before this segment
    spacing_time: Optional[int] # spacing between this and previous segment; adjusted spacing if gap before

    @property
    def duration(self) -> "TimeMs":
        """Get segment duration in milliseconds."""
        return TimeMs(self.end - self.start)

    @property
    def duration_sec(self) -> float:
        return self.duration.to_seconds()

    # ------------------------------------------------------------------- #
    # IMPLEMENTATION NOTE
    # Convenience wrappers returning the new Time abstraction so can
    # start migrating call‑sites incrementally without touching the int‑ms
    # fields just yet.
    # ------------------------------------------------------------------- #
    @property
    def start_time(self) -> "TimeMs":
        return self.start

    @property
    def end_time(self) -> "TimeMs":
        return self.end

    @property
    def mapped_start(self):
        """Downstream registry field set by the audio handler"""
        return self.start if self.audio_map_start is None else self.audio_map_start

    @property
    def mapped_end(self):
        if self.audio_map_start is None:
            return self.end 
        else:
            return self.audio_map_start + int(self.duration) 

    def normalize(self) -> None:
        """Normalize the duration of the segment to be nonzero and validate start/end values."""
        # Validate that start and end are non-negative integers
        if not isinstance(self.start, int) or not isinstance(self.end, int):
            raise ValueError("Segment start and end must be integers, "
                             f"got start={self.start}, end={self.end}")
        if self.start < 0 or self.end < 0:
            raise ValueError(f"Segment start and end must be non-negative, "
                             f"got start={self.start}, end={self.end}")

        # Explicitly handle negative durations
        if self.end < self.start:
            logger.warning(
                f"Invalid segment duration detected: start ({self.start}) > end ({self.end}). "
                "Adjusting end to ensure minimum duration of 1."
            )
            self.end = TimeMs(self.start + 1)  # set minimum nonzero duration

        # Ensure minimum nonzero duration
        if self.start == self.end:
            logger.warning(
                f"Zero segment duration detected: start ({self.start}) == end ({self.end}). "
                "Adjusting end to ensure minimum duration of 1."
            )
            self.end = TimeMs(self.start + 1)  # set minimum nonzero duration
audio_map_start instance-attribute
duration property

Get segment duration in milliseconds.

duration_sec property
end instance-attribute
end_time property
gap_before instance-attribute
mapped_end property
mapped_start property

Downstream registry field set by the audio handler

spacing_time instance-attribute
speaker instance-attribute
start instance-attribute
start_time property
normalize()

Normalize the duration of the segment to be nonzero and validate start/end values.

Source code in src/tnh_scholar/audio_processing/diarization/models.py
def normalize(self) -> None:
    """Normalize the duration of the segment to be nonzero and validate start/end values."""
    # Validate that start and end are non-negative integers
    if not isinstance(self.start, int) or not isinstance(self.end, int):
        raise ValueError("Segment start and end must be integers, "
                         f"got start={self.start}, end={self.end}")
    if self.start < 0 or self.end < 0:
        raise ValueError(f"Segment start and end must be non-negative, "
                         f"got start={self.start}, end={self.end}")

    # Explicitly handle negative durations
    if self.end < self.start:
        logger.warning(
            f"Invalid segment duration detected: start ({self.start}) > end ({self.end}). "
            "Adjusting end to ensure minimum duration of 1."
        )
        self.end = TimeMs(self.start + 1)  # set minimum nonzero duration

    # Ensure minimum nonzero duration
    if self.start == self.end:
        logger.warning(
            f"Zero segment duration detected: start ({self.start}) == end ({self.end}). "
            "Adjusting end to ensure minimum duration of 1."
        )
        self.end = TimeMs(self.start + 1)  # set minimum nonzero duration
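The rules `normalize()` enforces can be restated as a self-contained function over plain integers (a sketch of the same invariants, without the pydantic model or logging):

```python
def normalize_span(start, end):
    """Reject negative times; clamp zero or negative durations to 1 ms,
    mirroring the rules in DiarizedSegment.normalize()."""
    if start < 0 or end < 0:
        raise ValueError(f"start and end must be non-negative, got start={start}, end={end}")
    if end <= start:
        end = start + 1  # enforce minimum nonzero duration
    return start, end
```

Note that the real method mutates `self.end` in place and logs a warning rather than returning a new span.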
SpeakerBlock

Bases: BaseModel

A block of contiguous or near-contiguous segments spoken by the same speaker.

Used as a higher-level abstraction over diarization segments to simplify chunking strategies (e.g., language-aware sampling, re-segmentation).

Source code in src/tnh_scholar/audio_processing/diarization/models.py
class SpeakerBlock(BaseModel):
    """A block of contiguous or near-contiguous segments spoken by the same speaker.

    Used as a higher-level abstraction over diarization segments to simplify
    chunking strategies (e.g., language-aware sampling, re-segmentation).
    """

    speaker: str
    segments: list["DiarizedSegment"]

    class Config:
        arbitrary_types_allowed = True

    @property
    def start(self) -> "TimeMs":
        return TimeMs(self.segments[0].start)

    @property
    def end(self) -> "TimeMs":
        return TimeMs(self.segments[-1].end)

    @property
    def duration(self) -> "TimeMs":
        return TimeMs(self.end - self.start)

    @property
    def duration_sec(self) -> float:
        return self.duration.to_seconds()

    @property
    def segment_count(self) -> int:
        return len(self.segments)

    def to_dict(self) -> dict:
        """Custom serializer for SpeakerBlock with validation."""
        # Validate speaker
        if not isinstance(self.speaker, str) or not self.speaker:
            logger.error("SpeakerBlock.to_dict: 'speaker' must be a non-empty string.")
            raise ValueError("'speaker' must be a non-empty string.")

        # Validate segments
        if not isinstance(self.segments, list) or not self.segments:
            logger.error("SpeakerBlock.to_dict: 'segments' must be a non-empty list.")
            raise ValueError("'segments' must be a non-empty list of DiarizedSegment.")

        for idx, segment in enumerate(self.segments):
            if not isinstance(segment, DiarizedSegment):
                logger.error(f"SpeakerBlock.to_dict: Segment at index {idx} is not a DiarizedSegment.")
                raise TypeError(f"Segment at index {idx} is not a DiarizedSegment.")

        # Validate start/end/duration
        try:
            start = int(self.start)
            end = int(self.end)
            duration = int(self.duration)
            duration_sec = float(self.duration_sec)
            segment_count = int(self.segment_count)
        except Exception as e:
            logger.error(f"SpeakerBlock.to_dict: Error computing time fields: {e}")
            raise

        return {
            "speaker": self.speaker,
            "segments": [segment.model_dump() for segment in self.segments],
            "start": start,
            "end": end,
            "duration": duration,
            "duration_sec": duration_sec,
            "segment_count": segment_count,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "SpeakerBlock":
        """
        Create a SpeakerBlock from a dictionary (output of to_dict).
        Args:
            data (dict): Dictionary with keys matching SpeakerBlock fields.
        Returns:
            SpeakerBlock: Deserialized SpeakerBlock instance.
        Raises:
            ValueError, TypeError: If validation fails.
        """
        if not isinstance(data, dict):
            logger.error("SpeakerBlock.from_dict: Input data must be a dictionary.")
            raise TypeError("Input data must be a dictionary.")

        if "speaker" not in data or not isinstance(data["speaker"], str) or not data["speaker"]:
            logger.error("SpeakerBlock.from_dict: 'speaker' must be a non-empty string.")
            raise ValueError("'speaker' must be a non-empty string.")

        if "segments" not in data or not isinstance(data["segments"], list) or not data["segments"]:
            logger.error("SpeakerBlock.from_dict: 'segments' must be a non-empty list.")
            raise ValueError("'segments' must be a non-empty list.")

        segments = []
        for idx, seg in enumerate(data["segments"]):
            if not isinstance(seg, dict):
                logger.error(f"SpeakerBlock.from_dict: Segment at index {idx} is not a dict.")
                raise TypeError(f"Segment at index {idx} is not a dict.")
            try:
                segment = DiarizedSegment(**seg)
            except Exception as e:
                logger.error(
                    f"SpeakerBlock.from_dict: Failed to construct DiarizedSegment at index {idx}: {e}"
                    )
                raise
            segments.append(segment)

        return cls(speaker=data["speaker"], segments=segments)
duration property
duration_sec property
end property
segment_count property
segments instance-attribute
speaker instance-attribute
start property
Config
Source code in src/tnh_scholar/audio_processing/diarization/models.py
class Config:
    arbitrary_types_allowed = True
arbitrary_types_allowed = True class-attribute instance-attribute
from_dict(data) classmethod

Create a SpeakerBlock from a dictionary (output of to_dict).

Args:
    data (dict): Dictionary with keys matching SpeakerBlock fields.

Returns:
    SpeakerBlock: Deserialized SpeakerBlock instance.

Raises:
    ValueError, TypeError: If validation fails.

Source code in src/tnh_scholar/audio_processing/diarization/models.py
@classmethod
def from_dict(cls, data: dict) -> "SpeakerBlock":
    """
    Create a SpeakerBlock from a dictionary (output of to_dict).
    Args:
        data (dict): Dictionary with keys matching SpeakerBlock fields.
    Returns:
        SpeakerBlock: Deserialized SpeakerBlock instance.
    Raises:
        ValueError, TypeError: If validation fails.
    """
    if not isinstance(data, dict):
        logger.error("SpeakerBlock.from_dict: Input data must be a dictionary.")
        raise TypeError("Input data must be a dictionary.")

    if "speaker" not in data or not isinstance(data["speaker"], str) or not data["speaker"]:
        logger.error("SpeakerBlock.from_dict: 'speaker' must be a non-empty string.")
        raise ValueError("'speaker' must be a non-empty string.")

    if "segments" not in data or not isinstance(data["segments"], list) or not data["segments"]:
        logger.error("SpeakerBlock.from_dict: 'segments' must be a non-empty list.")
        raise ValueError("'segments' must be a non-empty list.")

    segments = []
    for idx, seg in enumerate(data["segments"]):
        if not isinstance(seg, dict):
            logger.error(f"SpeakerBlock.from_dict: Segment at index {idx} is not a dict.")
            raise TypeError(f"Segment at index {idx} is not a dict.")
        try:
            segment = DiarizedSegment(**seg)
        except Exception as e:
            logger.error(
                f"SpeakerBlock.from_dict: Failed to construct DiarizedSegment at index {idx}: {e}"
                )
            raise
        segments.append(segment)

    return cls(speaker=data["speaker"], segments=segments)
to_dict()

Custom serializer for SpeakerBlock with validation.

Source code in src/tnh_scholar/audio_processing/diarization/models.py
def to_dict(self) -> dict:
    """Custom serializer for SpeakerBlock with validation."""
    # Validate speaker
    if not isinstance(self.speaker, str) or not self.speaker:
        logger.error("SpeakerBlock.to_dict: 'speaker' must be a non-empty string.")
        raise ValueError("'speaker' must be a non-empty string.")

    # Validate segments
    if not isinstance(self.segments, list) or not self.segments:
        logger.error("SpeakerBlock.to_dict: 'segments' must be a non-empty list.")
        raise ValueError("'segments' must be a non-empty list of DiarizedSegment.")

    for idx, segment in enumerate(self.segments):
        if not isinstance(segment, DiarizedSegment):
            logger.error(f"SpeakerBlock.to_dict: Segment at index {idx} is not a DiarizedSegment.")
            raise TypeError(f"Segment at index {idx} is not a DiarizedSegment.")

    # Validate start/end/duration
    try:
        start = int(self.start)
        end = int(self.end)
        duration = int(self.duration)
        duration_sec = float(self.duration_sec)
        segment_count = int(self.segment_count)
    except Exception as e:
        logger.error(f"SpeakerBlock.to_dict: Error computing time fields: {e}")
        raise

    return {
        "speaker": self.speaker,
        "segments": [segment.model_dump() for segment in self.segments],
        "start": start,
        "end": end,
        "duration": duration,
        "duration_sec": duration_sec,
        "segment_count": segment_count,
    }
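The `to_dict`/`from_dict` pair is designed to round-trip: everything `to_dict` emits, `from_dict` validates and reconstructs. A minimal, self-contained sketch of that contract, using simplified stand-in classes rather than the library's `SpeakerBlock`/`DiarizedSegment` models:

```python
from dataclasses import dataclass
from typing import List

# Illustrative stand-ins only; fields are simplified (times in milliseconds).
@dataclass
class Seg:
    speaker: str
    start: int
    end: int

@dataclass
class Block:
    speaker: str
    segments: List[Seg]

    def to_dict(self) -> dict:
        # Validate before serializing, mirroring SpeakerBlock.to_dict.
        if not self.segments:
            raise ValueError("'segments' must be a non-empty list.")
        return {
            "speaker": self.speaker,
            "segments": [vars(s) for s in self.segments],
            "start": self.segments[0].start,
            "end": self.segments[-1].end,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "Block":
        if not isinstance(data, dict):
            raise TypeError("Input data must be a dictionary.")
        segs = [Seg(**s) for s in data["segments"]]
        return cls(speaker=data["speaker"], segments=segs)

block = Block("SPEAKER_00", [Seg("SPEAKER_00", 0, 1200), Seg("SPEAKER_00", 1200, 2500)])
restored = Block.from_dict(block.to_dict())
```

The round trip yields an equal object, which is the property the validation in both methods protects.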
protocols

Interfaces shared by diarization strategy classes.

AudioFetcher

Bases: Protocol

Abstract audio provider for probing a segment.

Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
class AudioFetcher(Protocol):
    """Abstract audio provider for probing a segment."""

    def extract_audio(self, start_ms: int, end_ms: int) -> Path: ...
extract_audio(start_ms, end_ms)
Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
def extract_audio(self, start_ms: int, end_ms: int) -> Path: ...
ChunkingStrategy

Bases: Protocol

Protocol every chunking strategy must satisfy.

Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
class ChunkingStrategy(Protocol):
    """
    Protocol every chunking strategy must satisfy.
    """

    def extract(self, segments: List[DiarizedSegment]) -> List[DiarizationChunk]: ...
extract(segments)
Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
def extract(self, segments: List[DiarizedSegment]) -> List[DiarizationChunk]: ...
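Because these interfaces are `typing.Protocol` classes, a strategy satisfies them by shape alone: no inheritance is required. A self-contained sketch with a stand-in protocol and a hypothetical `FixedSizeChunker` (not a library class):

```python
from typing import List, Protocol, runtime_checkable

@runtime_checkable
class ChunkingStrategy(Protocol):
    def extract(self, segments: List[dict]) -> List[dict]: ...

class FixedSizeChunker:
    """Satisfies the protocol structurally -- it never subclasses it."""
    def __init__(self, size: int):
        self.size = size

    def extract(self, segments: List[dict]) -> List[dict]:
        # Group consecutive segments into chunks of at most `size`.
        return [
            {"segments": segments[i : i + self.size]}
            for i in range(0, len(segments), self.size)
        ]

chunker = FixedSizeChunker(size=2)
assert isinstance(chunker, ChunkingStrategy)  # runtime structural check
chunks = chunker.extract([{"s": 0}, {"s": 1}, {"s": 2}])
```

The `@runtime_checkable` decorator makes the `isinstance` check verify method presence, which is how callers can accept any conforming strategy.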
DiarizationService

Bases: Protocol

Protocol for any diarization service.

Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
@runtime_checkable
class DiarizationService(Protocol):
    """Protocol for any diarization service."""

    def start(self, audio_path: Path, params: Optional[DiarizationParams] = None) -> str:
        """Start a diarization job and return an opaque job_id.""" 
        ...


    def get_response(self, job_id: str, *, wait_until_complete: bool = False) -> DiarizationResponse: 
        """Return the current state or final result as a DiarizationResponse.

        When `wait_until_complete` is True, the service blocks until a terminal
        state (succeeded/failed/timeout) and returns that envelope.
        """
        ...

    def generate(
        self,
        audio_path: Path,
        params: Optional[DiarizationParams] = None,
        *,
        wait_until_complete: bool = True,
    ) -> DiarizationResponse:
        """One-shot convenience: start + (optionally) wait + fetch + map.

        Implementations may optimize this path; default behavior can be
        start() followed by get_response().
        """
        ...
generate(audio_path, params=None, *, wait_until_complete=True)
Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
def generate(
    self,
    audio_path: Path,
    params: Optional[DiarizationParams] = None,
    *,
    wait_until_complete: bool = True,
) -> DiarizationResponse:
    """One-shot convenience: start + (optionally) wait + fetch + map.

    Implementations may optimize this path; default behavior can be
    start() followed by get_response().
    """
    ...
get_response(job_id, *, wait_until_complete=False)

Return the current state or final result as a DiarizationResponse.

When wait_until_complete is True, the service blocks until a terminal state (succeeded/failed/timeout) and returns that envelope.

Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
def get_response(self, job_id: str, *, wait_until_complete: bool = False) -> DiarizationResponse: 
    """Return the current state or final result as a DiarizationResponse.

    When `wait_until_complete` is True, the service blocks until a terminal
    state (succeeded/failed/timeout) and returns that envelope.
    """
    ...
start(audio_path, params=None)

Start a diarization job and return an opaque job_id.

Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
def start(self, audio_path: Path, params: Optional[DiarizationParams] = None) -> str:
    """Start a diarization job and return an opaque job_id.""" 
    ...
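The `start` / `get_response` / `generate` contract can be exercised against any conforming implementation. A sketch with an in-memory fake (dicts stand in for `DiarizationResponse`; `FakeDiarizationService` is illustrative, not a library class):

```python
from typing import Optional

class FakeDiarizationService:
    """In-memory stand-in satisfying the DiarizationService protocol shape."""

    def __init__(self):
        self._jobs = {}

    def start(self, audio_path: str, params: Optional[dict] = None) -> str:
        # Return an opaque job_id, as the protocol requires.
        job_id = f"job-{len(self._jobs) + 1}"
        self._jobs[job_id] = {"status": "pending", "job_id": job_id}
        return job_id

    def get_response(self, job_id: str, *, wait_until_complete: bool = False) -> dict:
        job = self._jobs[job_id]
        if wait_until_complete:
            # A real service would block/poll here until a terminal state.
            job["status"] = "succeeded"
        return job

    def generate(self, audio_path: str, params: Optional[dict] = None, *,
                 wait_until_complete: bool = True) -> dict:
        # Default one-shot path: start() followed by get_response().
        job_id = self.start(audio_path, params)
        return self.get_response(job_id, wait_until_complete=wait_until_complete)

svc = FakeDiarizationService()
resp = svc.generate("talk.mp3")
```

This mirrors the documented default behavior of `generate`: start the job, then fetch the terminal envelope.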
LanguageDetector

Bases: Protocol

Abstract language detector (e.g., fastText, Whisper-lang).

Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
class LanguageDetector(Protocol):
    """Abstract language detector (e.g., fastText, Whisper-lang)."""

    def detect(self, audio: AudioSegment, format_str: str) -> Optional[str]: ...
detect(audio, format_str)
Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
def detect(self, audio: AudioSegment, format_str: str) -> Optional[str]: ...
ResultWriter

Bases: Protocol

Port for persisting diarization results.

Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
class ResultWriter(Protocol):
    """Port for persisting diarization results."""

    def write(self, path: Path, response: DiarizationResponse) -> Path:
        ...
write(path, response)
Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
def write(self, path: Path, response: DiarizationResponse) -> Path:
    ...
SegmentAdapter

Bases: Protocol

Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
class SegmentAdapter(Protocol):
    def to_segments(
        self, 
        data: Any
        ) -> List[DiarizedSegment]: ...
to_segments(data)
Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
def to_segments(
    self, 
    data: Any
    ) -> List[DiarizedSegment]: ...
pyannote_adapter
logger = get_child_logger(__name__) module-attribute
PyannoteAdapter

Bases: SegmentAdapter

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_adapter.py
class PyannoteAdapter(SegmentAdapter):
    def __init__(self, config: DiarizationConfig = DiarizationConfig()):
        self.config = config

    def to_segments(self, data: Dict[str, List['PyannoteEntry']]) -> List[DiarizedSegment]:
        """
        Convert a pyannote.ai diarization result dict to a list of DiarizedSegment objects.
        """
        entries = self._extract_entries(data)
        valid_entries = self._validate_pyannote_entries(entries)
        segments: List[DiarizedSegment] = []
        for e in valid_entries:
            segment = DiarizedSegment(
                speaker=str(e.get("speaker", "SPEAKER_00")),
                start=TimeMs.from_seconds(float(e["start"])),
                end=TimeMs.from_seconds(float(e["end"])),
                audio_map_start=None,
                gap_before=None,
                spacing_time=None,
            )
            segments.append(segment)
        return self._sort_and_normalize_segments(segments)

    def to_response(
        self, jsr: JobStatusResponse
    ) -> DiarizationResponse:
        """
        Convert a JobStatusResponse to a DiarizationResponse (domain layer).
        """
        if self._is_successful(jsr):
            return self._build_succeeded(jsr)
        if self._is_outcome_failure(jsr):
            return self._build_outcome_failure(jsr)
        if self._is_api_failure(jsr):
            return self._build_api_failed(jsr)
        if self._is_pending(jsr):
            return self._build_pending(jsr)
        if self._is_running(jsr):
            return self._build_running(jsr)
        return self._build_fallback(jsr)

    def _extract_entries(self, payload: Dict[str, Any] | None) -> List[dict[str, Any]]:
        raw = payload or {}
        if isinstance(raw.get("diarization"), list):
            return list(raw["diarization"])
        segments = raw.get("segments")
        if isinstance(segments, list):
            return list(segments)
        ann = raw.get("annotation")
        if isinstance(ann, dict) and isinstance(ann.get("segments"), list):
            return list(ann["segments"])
        logger.warning(
            "Unexpected payload shape in _extract_entries: %r", payload
        )
        return []

    def _validate_pyannote_entries(self, entries: List[dict[str, Any]]) -> List[dict[str, Any]]:
        valid = []
        for e in entries:
            if not isinstance(e, dict):
                logger.warning("Entry is not a dict: %r", e)
                continue
            if any(k not in e for k in ("start", "end")):
                logger.warning("Missing 'start' or 'end' in entry: %r", e)
                continue
            try:
                float(e["start"])
                float(e["end"])
            except (ValueError, TypeError):
                logger.warning("Non-numeric 'start' or 'end' in entry: %r", e)
                continue
            valid.append(e)
        return valid

    def _sort_and_normalize_segments(
        self, segments: List[DiarizedSegment]
        ) -> List[DiarizedSegment]:
        self._sort_by_start(segments)
        for segment in segments:
            segment.normalize()
        return segments

    def _sort_by_start(self, segments: List[DiarizedSegment]) -> None:
        segments.sort(key=lambda segment: segment.start)

    def _map_outcome_to_error(self, outcome: PollOutcome, status: Optional[JobStatus]) -> ErrorCode:
        if outcome == PollOutcome.SUCCEEDED:
            logger.warning(
                "PollOutcome.SUCCEEDED was mapped to ErrorCode.UNKNOWN in map_outcome_to_error. "
                "This indicates a logic error."
            )
            return ErrorCode.UNKNOWN
        if outcome == PollOutcome.FAILED:
            return ErrorCode.API_ERROR
        if outcome == PollOutcome.TIMEOUT:
            return ErrorCode.TIMEOUT
        if outcome == PollOutcome.NETWORK_ERROR:
            return ErrorCode.TRANSIENT
        if outcome == PollOutcome.INTERRUPTED:
            return ErrorCode.CANCELLED
        if outcome == PollOutcome.ERROR:
            if status in (JobStatus.PENDING, JobStatus.RUNNING):
                return ErrorCode.TRANSIENT
            return ErrorCode.UNKNOWN
        return ErrorCode.UNKNOWN

    def _is_successful(self, jsr: JobStatusResponse) -> bool:
        return jsr.outcome == PollOutcome.SUCCEEDED and jsr.status == JobStatus.SUCCEEDED

    def _is_outcome_failure(self, jsr: JobStatusResponse) -> bool:
        return jsr.outcome in (
            PollOutcome.TIMEOUT,
            PollOutcome.NETWORK_ERROR,
            PollOutcome.INTERRUPTED,
            PollOutcome.ERROR,
        )

    def _is_api_failure(self, jsr: JobStatusResponse) -> bool:
        return jsr.status == JobStatus.FAILED

    def _is_pending(self, jsr: JobStatusResponse) -> bool:
        return jsr.status == JobStatus.PENDING

    def _is_running(self, jsr: JobStatusResponse) -> bool:
        return jsr.status == JobStatus.RUNNING

    def _build_succeeded(self, jsr: JobStatusResponse) -> DiarizationSucceeded:
        payload = jsr.payload or {}
        segments = self.to_segments(payload)
        num_speakers = payload.get("numSpeakers", payload.get("num_speakers"))
        return DiarizationSucceeded(
            status="succeeded",
            job_id=jsr.job_id,
            result=DiarizationResult(segments=segments, num_speakers=num_speakers, metadata=None),
            raw=jsr.model_dump(mode="json"),
        )

    def _build_outcome_failure(self, jsr: JobStatusResponse) -> DiarizationFailed:
        code = self._map_outcome_to_error(jsr.outcome, jsr.status)
        message = jsr.server_error_msg or "Null Message"
        return DiarizationFailed(
            status="failed",
            job_id=jsr.job_id,
            error=ErrorInfo(
                code=code,
                message=message,
                details={
                    "outcome": jsr.outcome.value,
                    "status": jsr.status.value if jsr.status else None,
                    "polls": jsr.polls,
                    "elapsed_s": jsr.elapsed_s,
                },
            ),
            raw=jsr.model_dump(mode="json"),
        )

    def _build_api_failed(self, jsr: JobStatusResponse) -> DiarizationFailed:
        return DiarizationFailed(
            status="failed",
            job_id=jsr.job_id,
            error=ErrorInfo(
                code=ErrorCode.API_ERROR,
                message=jsr.server_error_msg or "Remote job failed",
                details={"status": jsr.status.value if jsr.status else None},
            ),
            raw=jsr.model_dump(mode="json"),
        )

    def _build_pending(self, jsr: JobStatusResponse) -> DiarizationPending:
        return DiarizationPending(status="pending", job_id=jsr.job_id, raw=jsr.model_dump(mode="json"))

    def _build_running(self, jsr: JobStatusResponse) -> DiarizationRunning:
        return DiarizationRunning(status="running", job_id=jsr.job_id, raw=jsr.model_dump(mode="json"))

    def _build_fallback(self, jsr: JobStatusResponse) -> DiarizationFailed:
        return DiarizationFailed(
            status="failed",
            job_id=jsr.job_id,
            error=ErrorInfo(
                code=ErrorCode.UNKNOWN,
                message=(
                    f"Unknown outcome/status combination: outcome={jsr.outcome.value}, "
                    f"status={getattr(jsr.status, 'value', None)}"
                ),
                details={
                    "outcome": jsr.outcome.value,
                    "status": jsr.status.value if jsr.status else None,
                    "polls": jsr.polls,
                    "elapsed_s": jsr.elapsed_s,
                },
            ),
            raw=jsr.model_dump(mode="json"),
        )

    def failed_start(self) -> DiarizationFailed:
        return DiarizationFailed(
            status="failed",
            job_id=None,
            error=ErrorInfo(
                code=ErrorCode.TRANSIENT,
                message="Job failed to upload or start.",
                details=None,
            ),
        )
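For the unambiguous poll outcomes, the `if` chain in `_map_outcome_to_error` behaves like a table lookup. A sketch with stand-in enums (not the library's; the `ERROR` outcome, which additionally consults job status, is folded into the `UNKNOWN` default here):

```python
from enum import Enum

class PollOutcome(Enum):
    FAILED = "failed"
    TIMEOUT = "timeout"
    NETWORK_ERROR = "network_error"
    INTERRUPTED = "interrupted"

class ErrorCode(Enum):
    API_ERROR = "api_error"
    TIMEOUT = "timeout"
    TRANSIENT = "transient"
    CANCELLED = "cancelled"
    UNKNOWN = "unknown"

# One-to-one mapping for the terminal outcomes.
OUTCOME_TO_ERROR = {
    PollOutcome.FAILED: ErrorCode.API_ERROR,
    PollOutcome.TIMEOUT: ErrorCode.TIMEOUT,
    PollOutcome.NETWORK_ERROR: ErrorCode.TRANSIENT,
    PollOutcome.INTERRUPTED: ErrorCode.CANCELLED,
}

def map_outcome(outcome: PollOutcome) -> ErrorCode:
    # Anything unmapped falls back to UNKNOWN.
    return OUTCOME_TO_ERROR.get(outcome, ErrorCode.UNKNOWN)
```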
config = config instance-attribute
__init__(config=DiarizationConfig())
Source code in src/tnh_scholar/audio_processing/diarization/pyannote_adapter.py
def __init__(self, config: DiarizationConfig = DiarizationConfig()):
    self.config = config
failed_start()
Source code in src/tnh_scholar/audio_processing/diarization/pyannote_adapter.py
def failed_start(self) -> DiarizationFailed:
    return DiarizationFailed(
        status="failed",
        job_id=None,
        error=ErrorInfo(
            code=ErrorCode.TRANSIENT,
            message="Job failed to upload or start.",
            details=None,
        ),
    )
to_response(jsr)

Convert a JobStatusResponse to a DiarizationResponse (domain layer).

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_adapter.py
def to_response(
    self, jsr: JobStatusResponse
) -> DiarizationResponse:
    """
    Convert a JobStatusResponse to a DiarizationResponse (domain layer).
    """
    if self._is_successful(jsr):
        return self._build_succeeded(jsr)
    if self._is_outcome_failure(jsr):
        return self._build_outcome_failure(jsr)
    if self._is_api_failure(jsr):
        return self._build_api_failed(jsr)
    if self._is_pending(jsr):
        return self._build_pending(jsr)
    if self._is_running(jsr):
        return self._build_running(jsr)
    return self._build_fallback(jsr)
to_segments(data)

Convert a pyannote.ai diarization result dict to a list of DiarizedSegment objects.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_adapter.py
def to_segments(self, data: Dict[str, List['PyannoteEntry']]) -> List[DiarizedSegment]:
    """
    Convert a pyannote.ai diarization result dict to a list of DiarizedSegment objects.
    """
    entries = self._extract_entries(data)
    valid_entries = self._validate_pyannote_entries(entries)
    segments: List[DiarizedSegment] = []
    for e in valid_entries:
        segment = DiarizedSegment(
            speaker=str(e.get("speaker", "SPEAKER_00")),
            start=TimeMs.from_seconds(float(e["start"])),
            end=TimeMs.from_seconds(float(e["end"])),
            audio_map_start=None,
            gap_before=None,
            spacing_time=None,
        )
        segments.append(segment)
    return self._sort_and_normalize_segments(segments)
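`to_segments` relies on `_extract_entries`, which tolerates several payload shapes: entries may sit under `"diarization"`, `"segments"`, or nested as `"annotation" -> "segments"`. A self-contained sketch of that lookup order (a simplified stand-in, not the library method):

```python
from typing import Any, Dict, List, Optional

def extract_entries(payload: Optional[Dict[str, Any]]) -> List[dict]:
    raw = payload or {}
    # Preferred shape: top-level "diarization" list.
    if isinstance(raw.get("diarization"), list):
        return list(raw["diarization"])
    # Fallback: top-level "segments" list.
    if isinstance(raw.get("segments"), list):
        return list(raw["segments"])
    # Last resort: nested "annotation" -> "segments".
    ann = raw.get("annotation")
    if isinstance(ann, dict) and isinstance(ann.get("segments"), list):
        return list(ann["segments"])
    return []  # unknown shape: caller logs a warning and gets no entries

entry = {"speaker": "SPEAKER_00", "start": 0.0, "end": 1.2}
assert extract_entries({"diarization": [entry]}) == [entry]
assert extract_entries({"annotation": {"segments": [entry]}}) == [entry]
assert extract_entries(None) == []
```

Entries that survive extraction are then validated (numeric `start`/`end`) before being turned into `DiarizedSegment` objects.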
pyannote_client

pyannote_client.py

Client interface for interacting with the pyannote.ai speaker diarization API.

This module provides a robust, object-oriented client for uploading audio files, starting diarization jobs, polling for job completion, and retrieving results from the pyannote.ai API. It includes retry logic, configurable timeouts, and support for advanced diarization parameters.

Typical usage

client = PyannoteClient(api_key="your_api_key")
media_id = client.upload_audio(Path("audio.mp3"))
job_id = client.start_diarization(media_id)
result = client.poll_job_until_complete(job_id)
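The retry behavior the module describes (exponential backoff with jitter around uploads and status checks) is implemented with tenacity decorators in the client itself; the pattern can be sketched with the standard library alone. `call_with_retry` and `flaky` are illustrative names, not library API:

```python
import random
import time

def call_with_retry(fn, attempts: int = 3, initial: float = 1.0,
                    max_wait: float = 10.0):
    """Retry fn on ConnectionError with jittered exponential backoff."""
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except ConnectionError:
            if attempt == attempts:
                raise  # retries exhausted; surface the error to the caller
            # Double the wait each attempt, cap it, then apply jitter.
            wait = min(max_wait, initial * (2 ** (attempt - 1)))
            time.sleep(wait * random.uniform(0.5, 1.0))

calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient network error")
    return "ok"

result = call_with_retry(flaky, initial=0.01, max_wait=0.05)
```

Retrying only on network-level exceptions while failing fast on permanent errors (auth, missing files) is the same split the client applies.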

JOB_ID_FIELD = 'jobId' module-attribute
logger = get_child_logger(__name__) module-attribute
APIKeyError

Bases: Exception

Raised when API key is missing or invalid.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_client.py
class APIKeyError(Exception):
    """Raised when API key is missing or invalid."""
PyannoteClient

Client for interacting with the pyannote.ai speaker diarization API.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_client.py
class PyannoteClient:
    """Client for interacting with the pyannote.ai speaker diarization API."""

    def __init__(self, api_key: Optional[str] = None, config: Optional[PyannoteConfig] = None):
        """
        Initialize with API key.

        Args:
            api_key: Pyannote.ai API key (defaults to environment variable)
        """
        self.api_key = api_key or os.getenv("PYANNOTEAI_API_TOKEN")
        if not self.api_key:
            raise APIKeyError(
                "API key is required. Set PYANNOTEAI_API_TOKEN environment "
                "variable or pass as parameter"
            )

        self.config = config or PyannoteConfig()
        self.polling_config = self.config.polling_config

        # Upload-specific timeouts (longer than general calls)
        self.upload_timeout = self.config.upload_timeout
        self.upload_max_retries = self.config.upload_max_retries
        self.network_timeout = self.config.network_timeout

        self.headers = {"Authorization": f"Bearer {self.api_key}"}

    # -----------------------
    # Upload helpers
    # -----------------------
    def _create_media_id(self) -> str:
        """Generate a unique media ID."""
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
        return f"{self.config.media_prefix}{timestamp}"

    def _upload_file(self, file_path: Path, upload_url: str) -> bool:
        """
        Upload file to the provided URL.

        Args:
            file_path: Path to the file to upload
            upload_url: URL to upload to

        Returns:
            bool: True if upload successful, False otherwise
        """
        try:
            logger.info(f"Uploading file to Pyannote.ai: {file_path}")
            with open(file_path, "rb") as file_data:
                upload_response = requests.put(
                    upload_url,
                    data=file_data,
                    headers={"Content-Type": self.config.media_content_type},
                    timeout=self.upload_timeout,
                )

            upload_response.raise_for_status()
            logger.info("File uploaded successfully")
            return True

        except requests.RequestException as e:
            logger.error(f"Failed to upload file: {e}")
            return False

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential_jitter(exp_base=2, initial=3, max=30),
        retry=retry_if_exception_type(
            (requests.RequestException, requests.Timeout, requests.ConnectionError)
            ),
    )
    def upload_audio(self, file_path: Path) -> Optional[str]:
        """
        Upload audio file with retry logic for network robustness.

        Retries on network errors with exponential backoff.
        Fails fast on permanent errors (auth, file not found, etc.).
        """
        try:
            if not file_path.exists() or not file_path.is_file():
                logger.error(f"Audio file not found or is not a file: {file_path}")
                return None
        except OSError as e:
            logger.error(f"Error accessing audio file '{file_path}': {e}")
            return None

        try:
            file_size_mb = file_path.stat().st_size / (1024 * 1024)
        except OSError as e:
            logger.error(f"Error reading file size for '{file_path}': {e}")
            return None

        logger.info(f"Starting upload of {file_path.name} ({file_size_mb:.1f}MB)")

        try:
            # Create media ID
            media_id = self._create_media_id()
            logger.debug(f"Created media ID: {media_id}")

            # Get upload URL (this is fast, use normal timeout)
            upload_url = self._data_upload_url(media_id)
            if not upload_url:
                return None

            # Upload file (this is slow, use extended timeout)
            if self._upload_file(file_path, upload_url):
                logger.info(f"Upload completed successfully: {media_id}")
                return media_id
            else:
                logger.error(f"Upload failed for {file_path.name}")
                return None

        except Exception as e:
            # Log but don't retry - let tenacity handle retries
            logger.error(f"Upload attempt failed: {e}")
            raise  # Re-raise for tenacity to handle

    def _data_upload_url(self, media_id: str) -> Optional[str]:
        response = requests.post(
            self.config.media_input_endpoint,
            headers=self.headers,
            json={"url": media_id},
            timeout=self.network_timeout,
        )
        upload_url = self._extract_response_info(
            response, "url", "No upload URL in API response"
        )
        logger.debug(f"Got upload URL for media ID: {media_id}")
        return upload_url

    def _extract_response_info(self, response, response_type, error_msg):
        response.raise_for_status()
        info = response.json()
        if result := info.get(response_type):
            return result
        else:
            raise ValueError(error_msg)

    # -----------------------
    # Start job
    # -----------------------
    def start_diarization(self, media_id: str, params: Optional[DiarizationParams] = None) -> Optional[str]:
        """
        Start diarization job with pyannote.ai API.

        Args:
            media_id: The media ID from upload_audio
            params: Optional parameters for diarization

        Returns:
            Optional[str]: The job ID if started successfully, None otherwise
        """
        try:
            return self._send_payload(media_id, params)
        except requests.RequestException as e:
            logger.error(f"API request failed: {e}")
            return None
        except ValueError as e:
            logger.error(f"Invalid API response: {e}")
            return None

    def _send_payload(self, media_id, params):
        payload: Dict[str, Any] = {"url": media_id}
        if params:
            payload |= params.to_api_dict()
            logger.info(f"Starting diarization with params: {params}")
        logger.debug(f"Full payload: {payload}")

        response = requests.post(self.config.diarize_endpoint, headers=self.headers, json=payload)
        job_id = self._extract_response_info(
            response, JOB_ID_FIELD, "API response missing job ID"
        )
        logger.info(f"Diarization job {job_id} started successfully")
        return job_id

    # -----------------------
    # Status / Polling
    # -----------------------
    def check_job_status(self, job_id: str) -> Optional[JobStatusResponse]:
        """
        Check the status of a diarization job.

        Returns a typed transport model (JobStatusResponse) or None on failure.
        """
        return self._check_status_with_retry(job_id)

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential_jitter(exp_base=2, initial=1, max=10),
        retry=retry_if_exception_type(
            (requests.RequestException, requests.Timeout, requests.ConnectionError)
            ),
    )
    def _check_status_with_retry(self, job_id: str) -> Optional[JobStatusResponse]:
        """
        Check job status with network error retry logic.

        Retries network failures without killing the polling loop.
        Fails fast on API errors (auth, malformed response, etc.).

        Used as the status function in the JobPoller helper class.
        """
        try:
            endpoint = f"{self.config.job_status_endpoint}/{job_id}"
            response = requests.get(endpoint, headers=self.headers)
            response.raise_for_status()
            result = response.json()

            try:
                jsr = JobStatusResponse.model_validate(result)
            except Exception as ve:
                logger.error(f"Invalid status response for job {job_id}: {result} ({ve})")
                return None

            return jsr

        except requests.RequestException as e:
            logger.warning(f"Status check network error for job {job_id}: {e}")
            raise  # Let tenacity retry
        except Exception as e:
            logger.error(f"Unexpected status check error for job {job_id}: {e}")
            return None  # Don't retry on unexpected errors

    class JobPoller:
        """
        Generic job polling helper for long-running async jobs.
        """

        def __init__(self, status_fn, job_id: str, polling_config: PollingConfig):
            self.status_fn = status_fn
            self.job_id = job_id
            self.polling_config = polling_config
            self.poll_count = 0
            self.start_time = time.time()
            self.last_status: Optional[JobStatusResponse] = None
            self._last_error_reason: Optional[str] = None

        def _poll(self) -> JobStatusResponse | _PollSignal | None:
            self.poll_count += 1
            try:
                status_response = self.status_fn(self.job_id)
            except RetryError as e:
                self._last_error_reason = f"status check retry exhausted: {e}"
                logger.error(f"Status check retries exhausted for job {self.job_id}: {e}")
                return _PollSignal.STATUS_RETRY_EXHAUSTED

            if status_response is None:
                logger.error(f"Failed to get status for job {self.job_id} after retries")
                self._last_error_reason = "status response None"
                return None

            # track last known status for timeout / errors
            self.last_status = status_response

            status = status_response.status
            elapsed = time.time() - self.start_time

            if status == JobStatus.SUCCEEDED:
                logger.info(
                    f"Job {self.job_id} completed successfully after {elapsed:.1f}s ({self.poll_count} polls)"
                )
                return status_response

            if status == JobStatus.FAILED:
                logger.error(f"Job {self.job_id} failed: {status_response.server_error_msg}")
                return status_response

            # Job still running - signal the wait policy to schedule the next poll
            logger.info(f"Job {self.job_id} status: {status} (elapsed: {elapsed:.1f}s)")
            return _PollSignal.CONTINUE

        # --- Internal builders to attach polling context and craft JSRs ---
        def _attach_context(
            self, 
            base: Optional[JobStatusResponse], 
            *, 
            outcome: PollOutcome, 
            elapsed: float, 
            msg: Optional[str] = None
            ) -> JobStatusResponse:
            """Return a JSR carrying outcome + poll context. If `base` exists, preserve its
            status/payload/server_error_msg unless `msg` overrides it. Otherwise, synthesize a minimal JSR."""
            if base is None:
                return JobStatusResponse(
                    job_id=self.job_id,
                    outcome=outcome,
                    status=None,
                    server_error_msg=msg,
                    payload=None,
                    polls=self.poll_count,
                    elapsed_s=elapsed,
                )
            return JobStatusResponse(
                job_id=self.job_id,
                outcome=outcome,
                status=base.status,
                server_error_msg=msg if msg is not None else base.server_error_msg,
                payload=base.payload,
                polls=self.poll_count,
                elapsed_s=elapsed,
            )

        def _on_terminal(self, jsr: JobStatusResponse, *, elapsed: float) -> JobStatusResponse:
            """Attach poll context to a terminal server response (SUCCEEDED/FAILED)."""
            return JobStatusResponse(
                job_id=self.job_id,
                outcome=PollOutcome.SUCCEEDED if jsr.status == JobStatus.SUCCEEDED else PollOutcome.FAILED,
                status=jsr.status,
                server_error_msg=jsr.server_error_msg,
                payload=jsr.payload,
                polls=self.poll_count,
                elapsed_s=elapsed,
            )

        def _on_status_retry_exhausted(self, *, elapsed: float) -> JobStatusResponse:
            return self._attach_context(
                self.last_status, 
                outcome=PollOutcome.NETWORK_ERROR, 
                elapsed=elapsed, 
                msg=self._last_error_reason
                )

        def _on_invalid_payload(self, *, elapsed: float) -> JobStatusResponse:
            return self._attach_context(
                self.last_status, 
                outcome=PollOutcome.ERROR, 
                elapsed=elapsed, 
                msg="invalid status payload"
                )

        def _on_timeout(self, err: RetryError, *, elapsed: float) -> JobStatusResponse:
            return self._attach_context(
                self.last_status, 
                outcome=PollOutcome.TIMEOUT, 
                elapsed=elapsed, 
                msg=str(err)
                )

        def _on_interrupt(self, *, elapsed: float) -> JobStatusResponse:
            return self._attach_context(
                self.last_status, 
                outcome=PollOutcome.INTERRUPTED, 
                elapsed=elapsed, 
                msg="KeyboardInterrupt"
                )

        def _on_exception(self, err: Exception, *, elapsed: float) -> JobStatusResponse:
            return self._attach_context(
                self.last_status, 
                outcome=PollOutcome.ERROR, 
                elapsed=elapsed, 
                msg=str(err)
                )

        def run(self) -> JobStatusResponse:
            try:
                result = self._setup_and_run_poll()
                elapsed = time.time() - self.start_time

                if isinstance(result, JobStatusResponse):
                    # Terminal SUCCEEDED/FAILED (or unexpected non-terminal delivered): attach context
                    return self._on_terminal(result, elapsed=elapsed)

                if result is _PollSignal.STATUS_RETRY_EXHAUSTED:
                    return self._on_status_retry_exhausted(elapsed=elapsed)

                # None indicates invalid status payload or unexpected branch
                return self._on_invalid_payload(elapsed=elapsed)

            except RetryError as e:
                # Outer polling timeout
                elapsed = time.time() - self.start_time
                logger.info(f"Polling timed out for job {self.job_id} after {elapsed:.1f}s")
                return self._on_timeout(e, elapsed=elapsed)
            except KeyboardInterrupt:
                elapsed = time.time() - self.start_time
                logger.info(f"Polling for job {self.job_id} interrupted by user. Exiting.")
                return self._on_interrupt(elapsed=elapsed)
            except Exception as e:
                elapsed = time.time() - self.start_time
                logger.error(f"Polling failed for job {self.job_id}: {e}")
                return self._on_exception(e, elapsed=elapsed)

        def _setup_and_run_poll(self) -> Optional[JobStatusResponse | _PollSignal]:
            cfg = self.polling_config
            stop_policy = stop_never if cfg.polling_timeout is None else stop_after_delay(cfg.polling_timeout)
            retrying = Retrying(
                retry=retry_if_result(lambda result: result is _PollSignal.CONTINUE),
                stop=stop_policy,
                wait=wait_exponential_jitter(
                    exp_base=cfg.exp_base,
                    initial=cfg.initial_poll_time,
                    max=cfg.max_interval,
                ),
                reraise=True,
            )
            result = retrying(self._poll)
            if isinstance(result, JobStatusResponse):
                return result
            # could be STATUS_RETRY_EXHAUSTED sentinel or None
            logger.info(f"Polling ended with result: {result}")
            return result

    def poll_job_until_complete(
        self,
        job_id: str,
        estimated_duration: Optional[float] = None,
        timeout: Optional[float] = None,
        wait_until_complete: Optional[bool] = False,
    ) -> JobStatusResponse:
        """
        Poll until the job reaches a terminal state or a client-side stop condition, and
        return a unified JobStatusResponse (JSR) that includes both the server payload
        and polling context via `outcome`, `polls`, and `elapsed_s`.

        Args:
            job_id: Remote job identifier to poll.
            estimated_duration: Optional hint; currently unused (reserved for adaptive backoff).
            timeout: Optional hard timeout in seconds for this poll call. If provided, it overrides
                     the client's default polling timeout. Must not be combined with
                     `wait_until_complete=True` (raises ConfigurationError).
            wait_until_complete: If True, poll with no client-side timeout (subject to process lifetime).

        Returns:
            JobStatusResponse: unified transport + polling-context result.
        """
        if timeout is not None and wait_until_complete:
            raise ConfigurationError("Timeout cannot be set with wait_until_complete")

        # Derive an effective timeout for this call, without mutating client defaults
        effective_timeout = None if wait_until_complete else (
            timeout if timeout is not None else self.polling_config.polling_timeout
            )

        cfg = PollingConfig(
            polling_timeout=effective_timeout,
            initial_poll_time=self.polling_config.initial_poll_time,
            exp_base=self.polling_config.exp_base,
            max_interval=self.polling_config.max_interval,
        )

        poller = self.JobPoller(
            status_fn=self._check_status_with_retry,
            job_id=job_id,
            polling_config=cfg,
        )
        return poller.run()
api_key = api_key or os.getenv('PYANNOTEAI_API_TOKEN') instance-attribute
config = config or PyannoteConfig() instance-attribute
headers = {'Authorization': f'Bearer {self.api_key}'} instance-attribute
network_timeout = self.config.network_timeout instance-attribute
polling_config = self.config.polling_config instance-attribute
upload_max_retries = self.config.upload_max_retries instance-attribute
upload_timeout = self.config.upload_timeout instance-attribute
JobPoller

Generic job polling helper for long-running async jobs.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_client.py
class JobPoller:
    """
    Generic job polling helper for long-running async jobs.
    """

    def __init__(self, status_fn, job_id: str, polling_config: PollingConfig):
        self.status_fn = status_fn
        self.job_id = job_id
        self.polling_config = polling_config
        self.poll_count = 0
        self.start_time = time.time()
        self.last_status: Optional[JobStatusResponse] = None
        self._last_error_reason: Optional[str] = None

    def _poll(self) -> JobStatusResponse | _PollSignal | None:
        self.poll_count += 1
        try:
            status_response = self.status_fn(self.job_id)
        except RetryError as e:
            self._last_error_reason = f"status check retry exhausted: {e}"
            logger.error(f"Status check retries exhausted for job {self.job_id}: {e}")
            return _PollSignal.STATUS_RETRY_EXHAUSTED

        if status_response is None:
            logger.error(f"Failed to get status for job {self.job_id} after retries")
            self._last_error_reason = "status response None"
            return None

        # track last known status for timeout / errors
        self.last_status = status_response

        status = status_response.status
        elapsed = time.time() - self.start_time

        if status == JobStatus.SUCCEEDED:
            logger.info(
                f"Job {self.job_id} completed successfully after {elapsed:.1f}s ({self.poll_count} polls)"
            )
            return status_response

        if status == JobStatus.FAILED:
            logger.error(f"Job {self.job_id} failed: {status_response.server_error_msg}")
            return status_response

        # Job still running - signal the wait policy to schedule the next poll
        logger.info(f"Job {self.job_id} status: {status} (elapsed: {elapsed:.1f}s)")
        return _PollSignal.CONTINUE

    # --- Internal builders to attach polling context and craft JSRs ---
    def _attach_context(
        self, 
        base: Optional[JobStatusResponse], 
        *, 
        outcome: PollOutcome, 
        elapsed: float, 
        msg: Optional[str] = None
        ) -> JobStatusResponse:
        """Return a JSR carrying outcome + poll context. If `base` exists, preserve its
        status/payload/server_error_msg unless `msg` overrides it. Otherwise, synthesize a minimal JSR."""
        if base is None:
            return JobStatusResponse(
                job_id=self.job_id,
                outcome=outcome,
                status=None,
                server_error_msg=msg,
                payload=None,
                polls=self.poll_count,
                elapsed_s=elapsed,
            )
        return JobStatusResponse(
            job_id=self.job_id,
            outcome=outcome,
            status=base.status,
            server_error_msg=msg if msg is not None else base.server_error_msg,
            payload=base.payload,
            polls=self.poll_count,
            elapsed_s=elapsed,
        )

    def _on_terminal(self, jsr: JobStatusResponse, *, elapsed: float) -> JobStatusResponse:
        """Attach poll context to a terminal server response (SUCCEEDED/FAILED)."""
        return JobStatusResponse(
            job_id=self.job_id,
            outcome=PollOutcome.SUCCEEDED if jsr.status == JobStatus.SUCCEEDED else PollOutcome.FAILED,
            status=jsr.status,
            server_error_msg=jsr.server_error_msg,
            payload=jsr.payload,
            polls=self.poll_count,
            elapsed_s=elapsed,
        )

    def _on_status_retry_exhausted(self, *, elapsed: float) -> JobStatusResponse:
        return self._attach_context(
            self.last_status, 
            outcome=PollOutcome.NETWORK_ERROR, 
            elapsed=elapsed, 
            msg=self._last_error_reason
            )

    def _on_invalid_payload(self, *, elapsed: float) -> JobStatusResponse:
        return self._attach_context(
            self.last_status, 
            outcome=PollOutcome.ERROR, 
            elapsed=elapsed, 
            msg="invalid status payload"
            )

    def _on_timeout(self, err: RetryError, *, elapsed: float) -> JobStatusResponse:
        return self._attach_context(
            self.last_status, 
            outcome=PollOutcome.TIMEOUT, 
            elapsed=elapsed, 
            msg=str(err)
            )

    def _on_interrupt(self, *, elapsed: float) -> JobStatusResponse:
        return self._attach_context(
            self.last_status, 
            outcome=PollOutcome.INTERRUPTED, 
            elapsed=elapsed, 
            msg="KeyboardInterrupt"
            )

    def _on_exception(self, err: Exception, *, elapsed: float) -> JobStatusResponse:
        return self._attach_context(
            self.last_status, 
            outcome=PollOutcome.ERROR, 
            elapsed=elapsed, 
            msg=str(err)
            )

    def run(self) -> JobStatusResponse:
        try:
            result = self._setup_and_run_poll()
            elapsed = time.time() - self.start_time

            if isinstance(result, JobStatusResponse):
                # Terminal SUCCEEDED/FAILED (or unexpected non-terminal delivered): attach context
                return self._on_terminal(result, elapsed=elapsed)

            if result is _PollSignal.STATUS_RETRY_EXHAUSTED:
                return self._on_status_retry_exhausted(elapsed=elapsed)

            # None indicates invalid status payload or unexpected branch
            return self._on_invalid_payload(elapsed=elapsed)

        except RetryError as e:
            # Outer polling timeout
            elapsed = time.time() - self.start_time
            logger.info(f"Polling timed out for job {self.job_id} after {elapsed:.1f}s")
            return self._on_timeout(e, elapsed=elapsed)
        except KeyboardInterrupt:
            elapsed = time.time() - self.start_time
            logger.info(f"Polling for job {self.job_id} interrupted by user. Exiting.")
            return self._on_interrupt(elapsed=elapsed)
        except Exception as e:
            elapsed = time.time() - self.start_time
            logger.error(f"Polling failed for job {self.job_id}: {e}")
            return self._on_exception(e, elapsed=elapsed)

    def _setup_and_run_poll(self) -> Optional[JobStatusResponse | _PollSignal]:
        cfg = self.polling_config
        stop_policy = stop_never if cfg.polling_timeout is None else stop_after_delay(cfg.polling_timeout)
        retrying = Retrying(
            retry=retry_if_result(lambda result: result is _PollSignal.CONTINUE),
            stop=stop_policy,
            wait=wait_exponential_jitter(
                exp_base=cfg.exp_base,
                initial=cfg.initial_poll_time,
                max=cfg.max_interval,
            ),
            reraise=True,
        )
        result = retrying(self._poll)
        if isinstance(result, JobStatusResponse):
            return result
        # could be STATUS_RETRY_EXHAUSTED sentinel or None
        logger.info(f"Polling ended with result: {result}")
        return result
job_id = job_id instance-attribute
last_status = None instance-attribute
poll_count = 0 instance-attribute
polling_config = polling_config instance-attribute
start_time = time.time() instance-attribute
status_fn = status_fn instance-attribute
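The loop inside JobPoller (retry while the status function returns a CONTINUE sentinel, with capped exponential backoff between polls) can be sketched without tenacity. `Signal` and `poll_until_done` below are illustrative stand-ins, not part of the package; the injectable `sleep` parameter makes the backoff schedule observable in tests.

```python
import enum
import time

class Signal(enum.Enum):
    CONTINUE = "continue"

def poll_until_done(status_fn, initial=1.0, exp_base=2.0, max_interval=30.0,
                    max_polls=20, sleep=time.sleep):
    """Call status_fn until it returns something other than Signal.CONTINUE,
    sleeping with capped exponential backoff between polls (the real poller
    delegates this loop to tenacity's retry_if_result + wait_exponential_jitter)."""
    interval = initial
    for _ in range(max_polls):
        result = status_fn()
        if result is not Signal.CONTINUE:
            return result
        sleep(interval)
        interval = min(interval * exp_base, max_interval)
    return None  # poll budget exhausted; analogous to a client-side timeout
```

The real class returns a JobStatusResponse in every exit path instead of None; the sketch only shows the retry-on-sentinel control flow.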
__init__(status_fn, job_id, polling_config)
Source code in src/tnh_scholar/audio_processing/diarization/pyannote_client.py
def __init__(self, status_fn, job_id: str, polling_config: PollingConfig):
    self.status_fn = status_fn
    self.job_id = job_id
    self.polling_config = polling_config
    self.poll_count = 0
    self.start_time = time.time()
    self.last_status: Optional[JobStatusResponse] = None
    self._last_error_reason: Optional[str] = None
run()
Source code in src/tnh_scholar/audio_processing/diarization/pyannote_client.py
def run(self) -> JobStatusResponse:
    try:
        result = self._setup_and_run_poll()
        elapsed = time.time() - self.start_time

        if isinstance(result, JobStatusResponse):
            # Terminal SUCCEEDED/FAILED (or unexpected non-terminal delivered): attach context
            return self._on_terminal(result, elapsed=elapsed)

        if result is _PollSignal.STATUS_RETRY_EXHAUSTED:
            return self._on_status_retry_exhausted(elapsed=elapsed)

        # None indicates invalid status payload or unexpected branch
        return self._on_invalid_payload(elapsed=elapsed)

    except RetryError as e:
        # Outer polling timeout
        elapsed = time.time() - self.start_time
        logger.info(f"Polling timed out for job {self.job_id} after {elapsed:.1f}s")
        return self._on_timeout(e, elapsed=elapsed)
    except KeyboardInterrupt:
        elapsed = time.time() - self.start_time
        logger.info(f"Polling for job {self.job_id} interrupted by user. Exiting.")
        return self._on_interrupt(elapsed=elapsed)
    except Exception as e:
        elapsed = time.time() - self.start_time
        logger.error(f"Polling failed for job {self.job_id}: {e}")
        return self._on_exception(e, elapsed=elapsed)
__init__(api_key=None, config=None)

Initialize with API key.

Parameters:

    api_key (Optional[str], default None):
        Pyannote.ai API key (defaults to environment variable)
Source code in src/tnh_scholar/audio_processing/diarization/pyannote_client.py
def __init__(self, api_key: Optional[str] = None, config: Optional[PyannoteConfig] = None):
    """
    Initialize with API key.

    Args:
        api_key: Pyannote.ai API key (defaults to environment variable)
    """
    self.api_key = api_key or os.getenv("PYANNOTEAI_API_TOKEN")
    if not self.api_key:
        raise APIKeyError(
            "API key is required. Set PYANNOTEAI_API_TOKEN environment "
            "variable or pass as parameter"
        )

    self.config = config or PyannoteConfig()
    self.polling_config = self.config.polling_config

    # Upload-specific timeouts (longer than general calls)
    self.upload_timeout = self.config.upload_timeout
    self.upload_max_retries = self.config.upload_max_retries
    self.network_timeout = self.config.network_timeout

    self.headers = {"Authorization": f"Bearer {self.api_key}"}
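The constructor's key lookup (explicit argument wins, then the environment variable, otherwise an error) can be isolated as a small helper. `resolve_api_key` and the local `APIKeyError` below are illustrative stand-ins for the real constructor and the package's exception class.

```python
import os

class APIKeyError(RuntimeError):
    """Stand-in for the package's APIKeyError."""

def resolve_api_key(api_key=None, env_var="PYANNOTEAI_API_TOKEN"):
    """Mirror of PyannoteClient.__init__'s key resolution: an explicitly
    passed key takes precedence over the environment variable; a missing
    key is a hard error rather than a deferred failure at request time."""
    key = api_key or os.getenv(env_var)
    if not key:
        raise APIKeyError(
            f"API key is required. Set {env_var} environment variable or pass as parameter"
        )
    return key
```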
check_job_status(job_id)

Check the status of a diarization job.

Returns a typed transport model (JobStatusResponse) or None on failure.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_client.py
def check_job_status(self, job_id: str) -> Optional[JobStatusResponse]:
    """
    Check the status of a diarization job.

    Returns a typed transport model (JobStatusResponse) or None on failure.
    """
    return self._check_status_with_retry(job_id)
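check_job_status returns None rather than raising when the response fails validation. A minimal stand-in for that validate-or-None step (the real code uses `JobStatusResponse.model_validate`; `parse_job_status` below is purely illustrative):

```python
def parse_job_status(result):
    """Illustrative stand-in for model validation: return the status string
    from a raw response dict, or None when the shape is wrong, mirroring
    check_job_status's None-on-failure contract."""
    if not isinstance(result, dict):
        return None
    status = result.get("status")
    return status if isinstance(status, str) else None
```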
poll_job_until_complete(job_id, estimated_duration=None, timeout=None, wait_until_complete=False)

Poll until the job reaches a terminal state or a client-side stop condition, and return a unified JobStatusResponse (JSR) that includes both the server payload and polling context via outcome, polls, and elapsed_s.

Parameters:

    job_id (str, required):
        Remote job identifier to poll.

    estimated_duration (Optional[float], default None):
        Optional hint; currently unused (reserved for adaptive backoff).

    timeout (Optional[float], default None):
        Optional hard timeout in seconds for this poll call. If provided, it overrides
        the client's default polling timeout. Must not be combined with
        wait_until_complete=True (raises ConfigurationError).

    wait_until_complete (Optional[bool], default False):
        If True, poll with no client-side timeout (subject to process lifetime).

Returns:

    JobStatusResponse: unified transport + polling-context result.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_client.py
def poll_job_until_complete(
    self,
    job_id: str,
    estimated_duration: Optional[float] = None,
    timeout: Optional[float] = None,
    wait_until_complete: Optional[bool] = False,
) -> JobStatusResponse:
    """
    Poll until the job reaches a terminal state or a client-side stop condition, and
    return a unified JobStatusResponse (JSR) that includes both the server payload
    and polling context via `outcome`, `polls`, and `elapsed_s`.

    Args:
        job_id: Remote job identifier to poll.
        estimated_duration: Optional hint; currently unused (reserved for adaptive backoff).
        timeout: Optional hard timeout in seconds for this poll call. If provided, it overrides
                 the client's default polling timeout. Must not be combined with
                 `wait_until_complete=True` (raises ConfigurationError).
        wait_until_complete: If True, poll with no client-side timeout (subject to process lifetime).

    Returns:
        JobStatusResponse: unified transport + polling-context result.
    """
    if timeout is not None and wait_until_complete:
        raise ConfigurationError("Timeout cannot be set with wait_until_complete")

    # Derive an effective timeout for this call, without mutating client defaults
    effective_timeout = None if wait_until_complete else (
        timeout if timeout is not None else self.polling_config.polling_timeout
        )

    cfg = PollingConfig(
        polling_timeout=effective_timeout,
        initial_poll_time=self.polling_config.initial_poll_time,
        exp_base=self.polling_config.exp_base,
        max_interval=self.polling_config.max_interval,
    )

    poller = self.JobPoller(
        status_fn=self._check_status_with_retry,
        job_id=job_id,
        polling_config=cfg,
    )
    return poller.run()
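The timeout/wait_until_complete interaction above reduces to a small derivation that can be unit-tested in isolation. `effective_polling_timeout` is an illustrative extraction of that logic, with ValueError standing in for the package's ConfigurationError.

```python
def effective_polling_timeout(timeout, wait_until_complete, default_timeout):
    """Derive the timeout used for one poll call, mirroring
    poll_job_until_complete: an explicit timeout overrides the client
    default, and wait_until_complete=True means no client-side timeout."""
    if timeout is not None and wait_until_complete:
        # the real method raises ConfigurationError; ValueError stands in here
        raise ValueError("Timeout cannot be set with wait_until_complete")
    if wait_until_complete:
        return None
    return timeout if timeout is not None else default_timeout
```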
start_diarization(media_id, params=None)

Start diarization job with pyannote.ai API.

Parameters:

    media_id (str, required):
        The media ID from upload_audio

    params (Optional[DiarizationParams], default None):
        Optional parameters for diarization

Returns:

    Optional[str]: The job ID if started successfully, None otherwise

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_client.py
def start_diarization(self, media_id: str, params: Optional[DiarizationParams] = None) -> Optional[str]:
    """
    Start diarization job with pyannote.ai API.

    Args:
        media_id: The media ID from upload_audio
        params: Optional parameters for diarization

    Returns:
        Optional[str]: The job ID if started successfully, None otherwise
    """
    try:
        return self._send_payload(media_id, params)
    except requests.RequestException as e:
        logger.error(f"API request failed: {e}")
        return None
    except ValueError as e:
        logger.error(f"Invalid API response: {e}")
        return None
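The request body assembled by the underlying `_send_payload` is the media URL/ID plus any extra diarization parameters. A sketch of that merge, where a plain dict stands in for `params.to_api_dict()` and the `numSpeakers` key is a hypothetical parameter name, not one confirmed by this documentation:

```python
def build_diarization_payload(media_id, params=None):
    """Sketch of _send_payload's request body: start from the media
    reference, then merge in any optional API parameters."""
    payload = {"url": media_id}
    if params:
        payload |= params  # dict union-assignment, Python 3.9+
    return payload
```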
upload_audio(file_path)

Upload audio file with retry logic for network robustness.

Retries on network errors with exponential backoff. Fails fast on permanent errors (auth, file not found, etc.).

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_client.py
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential_jitter(exp_base=2, initial=3, max=30),
    retry=retry_if_exception_type(
        (requests.RequestException, requests.Timeout, requests.ConnectionError)
        ),
)
def upload_audio(self, file_path: Path) -> Optional[str]:
    """
    Upload audio file with retry logic for network robustness.

    Retries on network errors with exponential backoff.
    Fails fast on permanent errors (auth, file not found, etc.).
    """
    try:
        if not file_path.exists() or not file_path.is_file():
            logger.error(f"Audio file not found or is not a file: {file_path}")
            return None
    except OSError as e:
        logger.error(f"Error accessing audio file '{file_path}': {e}")
        return None

    try:
        file_size_mb = file_path.stat().st_size / (1024 * 1024)
    except OSError as e:
        logger.error(f"Error reading file size for '{file_path}': {e}")
        return None

    logger.info(f"Starting upload of {file_path.name} ({file_size_mb:.1f}MB)")

    try:
        # Create media ID
        media_id = self._create_media_id()
        logger.debug(f"Created media ID: {media_id}")

        # Get upload URL (this is fast, use normal timeout)
        upload_url = self._data_upload_url(media_id)
        if not upload_url:
            return None

        # Upload file (this is slow, use extended timeout)
        if self._upload_file(file_path, upload_url):
            logger.info(f"Upload completed successfully: {media_id}")
            return media_id
        else:
            logger.error(f"Upload failed for {file_path.name}")
            return None

    except Exception as e:
        # Log but don't retry - let tenacity handle retries
        logger.error(f"Upload attempt failed: {e}")
        raise  # Re-raise for tenacity to handle
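upload_audio performs pre-flight file checks before touching the network, returning None instead of raising on a bad path. Those checks can be isolated as a small helper; `validate_upload_file` is an illustrative extraction, not a function the package exposes.

```python
from pathlib import Path

def validate_upload_file(file_path: Path):
    """Mirror of upload_audio's pre-flight checks: return the size in MB
    for an existing regular file, or None (the method logs and returns
    None in the same cases)."""
    try:
        if not file_path.exists() or not file_path.is_file():
            return None
        return file_path.stat().st_size / (1024 * 1024)
    except OSError:
        return None
```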
pyannote_diarize
PYANNOTE_FILE_STR = '_pyannote_diarization' module-attribute
logger = get_child_logger(__name__) module-attribute
DiarizationProcessor

Orchestrator over a DiarizationService.

This layer delegates to the service for generation and handles persistence.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
class DiarizationProcessor:
    """Orchestrator over a DiarizationService.

    This layer delegates to the service for generation and handles persistence.
    """

    def __init__(
        self,
        audio_file_path: Path,
        output_path: Optional[Path] = None,
        *,
        service: Optional[DiarizationService] = None,
        params: Optional[DiarizationParams] = None,
        api_key: Optional[str] = None,
        writer: Optional[ResultWriter] = None,
    ) -> None:
        self.audio_file_path: Path = audio_file_path.resolve()
        if not self.audio_file_path.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_file_path}")

        # Default output path
        self.output_path: Path = (
            output_path.resolve()
            if output_path is not None
            else self.audio_file_path.parent / f"{self.audio_file_path.stem}{PYANNOTE_FILE_STR}.json"
        )

        # Service & config
        # If a concrete service is not provided, default to PyannoteService.
        # Only pass api_key to PyannoteClient if it is not None.
        default_client = PyannoteClient(api_key) if api_key is not None else PyannoteClient()
        self.service: DiarizationService = service or PyannoteService(default_client)
        self.params: Optional[DiarizationParams] = params
        self.writer: ResultWriter = writer or FileResultWriter()

        # Cached state
        self._last_response: Optional[DiarizationResponse] = None
        self._last_job_id: Optional[str] = None

    # ---- Two-phase job control (nice for UIs) --------------------------------

    def start(self) -> JobHandle:
        """Start a job and cache its job_id."""
        job_id = self.service.start(self.audio_file_path, params=self.params)
        if not job_id:
            raise RuntimeError("Diarization service returned empty job_id")
        self._last_job_id = job_id
        return JobHandle(job_id=job_id)

    def get_response(
        self, job: Optional[Union[JobHandle, str]] = None, *, wait_until_complete: bool = False
        ) -> DiarizationResponse:
        """Fetch current/final response for a job, caching the last response."""
        target_id: Optional[str]
        if isinstance(job, JobHandle):
            target_id = job.job_id
        else:
            target_id = job or self._last_job_id
        if target_id is None:
            raise ValueError(
                "No job_id provided and no previous job has been started. Call start() or pass a job_id."
            )
        resp = self.service.get_response(target_id, wait_until_complete=wait_until_complete)
        self._last_response = resp
        return resp

    # ---- One-shot path --------------------------------------------------------

    def generate(self, *, wait_until_complete: bool = True) -> DiarizationResponse:
        """One-shot convenience: delegate to the service and cache the response."""
        resp = self.service.generate(
            self.audio_file_path, 
            params=self.params, 
            wait_until_complete=wait_until_complete
            )
        self._last_response = resp
        # If the service exposes a job_id in the envelope, cache it for UIs
        # Do not fail on metadata issues; response is primary.
        try:
            job_id = getattr(resp, "job_id", None)
            if isinstance(job_id, str):
                self._last_job_id = job_id
        except (AttributeError, TypeError) as e:
            logger.warning(f"Could not extract job_id from response: {e}")
        return resp

    # ---- Persistence ----------------------------------------------------------

    def export(self, response: Optional[DiarizationResponse] = None) -> Path:
        """Write the provided or last response to `self.output_path`."""
        result = response or self._last_response
        if result is None:
            raise ValueError(
                "No DiarizationResponse available; call generate()/get_response() first or pass response="
                )
        return self.writer.write(self.output_path, result)
audio_file_path = audio_file_path.resolve() instance-attribute
output_path = output_path.resolve() if output_path is not None else self.audio_file_path.parent / f'{self.audio_file_path.stem}{PYANNOTE_FILE_STR}.json' instance-attribute
params = params instance-attribute
service = service or PyannoteService(default_client) instance-attribute
writer = writer or FileResultWriter() instance-attribute
__init__(audio_file_path, output_path=None, *, service=None, params=None, api_key=None, writer=None)
Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def __init__(
    self,
    audio_file_path: Path,
    output_path: Optional[Path] = None,
    *,
    service: Optional[DiarizationService] = None,
    params: Optional[DiarizationParams] = None,
    api_key: Optional[str] = None,
    writer: Optional[ResultWriter] = None,
) -> None:
    self.audio_file_path: Path = audio_file_path.resolve()
    if not self.audio_file_path.exists():
        raise FileNotFoundError(f"Audio file not found: {audio_file_path}")

    # Default output path
    self.output_path: Path = (
        output_path.resolve()
        if output_path is not None
        else self.audio_file_path.parent / f"{self.audio_file_path.stem}{PYANNOTE_FILE_STR}.json"
    )

    # Service & config
    # If a concrete service is not provided, default to PyannoteService.
    # Only pass api_key to PyannoteClient if it is not None.
    default_client = PyannoteClient(api_key) if api_key is not None else PyannoteClient()
    self.service: DiarizationService = service or PyannoteService(default_client)
    self.params: Optional[DiarizationParams] = params
    self.writer: ResultWriter = writer or FileResultWriter()

    # Cached state
    self._last_response: Optional[DiarizationResponse] = None
    self._last_job_id: Optional[str] = None
export(response=None)

Write the provided or last response to self.output_path.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def export(self, response: Optional[DiarizationResponse] = None) -> Path:
    """Write the provided or last response to `self.output_path`."""
    result = response or self._last_response
    if result is None:
        raise ValueError(
            "No DiarizationResponse available; call generate()/get_response() first or pass response="
            )
    return self.writer.write(self.output_path, result)
generate(*, wait_until_complete=True)

One-shot convenience: delegate to the service and cache the response.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def generate(self, *, wait_until_complete: bool = True) -> DiarizationResponse:
    """One-shot convenience: delegate to the service and cache the response."""
    resp = self.service.generate(
        self.audio_file_path, 
        params=self.params, 
        wait_until_complete=wait_until_complete
        )
    self._last_response = resp
    # If the service exposes a job_id in the envelope, cache it for UIs
    # Do not fail on metadata issues; response is primary.
    try:
        job_id = getattr(resp, "job_id", None)
        if isinstance(job_id, str):
            self._last_job_id = job_id
    except (AttributeError, TypeError) as e:
        logger.warning(f"Could not extract job_id from response: {e}")
    return resp
get_response(job=None, *, wait_until_complete=False)

Fetch current/final response for a job, caching the last response.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def get_response(
    self, job: Optional[Union[JobHandle, str]] = None, *, wait_until_complete: bool = False
    ) -> DiarizationResponse:
    """Fetch current/final response for a job, caching the last response."""
    target_id: Optional[str]
    if isinstance(job, JobHandle):
        target_id = job.job_id
    else:
        target_id = job or self._last_job_id
    if target_id is None:
        raise ValueError(
            "No job_id provided and no previous job has been started. Call start() or pass a job_id."
        )
    resp = self.service.get_response(target_id, wait_until_complete=wait_until_complete)
    self._last_response = resp
    return resp
start()

Start a job and cache its job_id.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def start(self) -> JobHandle:
    """Start a job and cache its job_id."""
    job_id = self.service.start(self.audio_file_path, params=self.params)
    if not job_id:
        raise RuntimeError("Diarization service returned empty job_id")
    self._last_job_id = job_id
    return JobHandle(job_id=job_id)
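The two-phase flow (start a job, then poll it) can be sketched with a hypothetical in-memory service; `FakeService` and its canned string statuses below are illustrative stand-ins for a real `DiarizationService`, which returns `DiarizationResponse` models:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeService:
    """Illustrative stand-in for a DiarizationService: start() submits and
    returns a job_id; get_response() reports canned string statuses."""
    statuses: List[str] = field(default_factory=lambda: ["pending", "running", "succeeded"])
    _cursor: int = 0

    def _current(self) -> str:
        return self.statuses[min(self._cursor, len(self.statuses) - 1)]

    def start(self, audio_path: str) -> str:
        return "job-123"

    def get_response(self, job_id: str, *, wait_until_complete: bool = False) -> str:
        if not wait_until_complete:
            status = self._current()      # non-blocking: report and advance
            self._cursor += 1
            return status
        while self._current() not in ("succeeded", "failed"):
            self._cursor += 1             # "block" until a terminal status
        return self._current()

svc = FakeService()
job_id = svc.start("talk.mp3")                              # phase 1: submit
first = svc.get_response(job_id)                            # phase 2: poll once
final = svc.get_response(job_id, wait_until_complete=True)  # poll to completion
```

This two-phase shape is what lets a UI show intermediate "pending"/"running" states before the final result arrives.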
FileResultWriter

Default file-system writer to JSON.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
class FileResultWriter:
    """Default file-system writer to JSON."""

    def write(self, path: Path, response: DiarizationResponse) -> Path:
        ensure_directory_exists(path.parent)
        write_str_to_file(path, response.model_dump_json(indent=2), overwrite=True)
        logger.debug(f"DiarizationResponse saved to {path}")
        return path
write(path, response)
Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def write(self, path: Path, response: DiarizationResponse) -> Path:
    ensure_directory_exists(path.parent)
    write_str_to_file(path, response.model_dump_json(indent=2), overwrite=True)
    logger.debug(f"DiarizationResponse saved to {path}")
    return path
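Because `writer` is any object exposing `write(path, response) -> Path`, callers can swap in their own persistence. A minimal sketch, where a plain dict and `json` stand in for the pydantic response (both illustrative):

```python
import json
import tempfile
from pathlib import Path

class DictResultWriter:
    """Illustrative writer: persists a plain dict as pretty-printed JSON,
    creating parent directories first (the ensure_directory_exists analogue)."""

    def write(self, path: Path, response: dict) -> Path:
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(response, indent=2))
        return path

out = Path(tempfile.mkdtemp()) / "diarization.json"
written = DictResultWriter().write(out, {"status": "succeeded", "num_speakers": 2})
```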
PyannoteService

Bases: DiarizationService

Concrete implementation of DiarizationService for pyannote.ai.

Bridges transport (PyannoteClient) and mapping (PyannoteAdapter) while exposing a clean domain-facing API.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
class PyannoteService(DiarizationService):
    """Concrete implementation of DiarizationService for pyannote.ai.

    Bridges transport (PyannoteClient) and mapping (PyannoteAdapter) while
    exposing a clean domain-facing API.
    """

    def __init__(self, client: Optional[PyannoteClient] = None, adapter: Optional[PyannoteAdapter] = None):
        self.client = client or PyannoteClient()
        self.adapter = adapter or PyannoteAdapter()

    # --- DiarizationService protocol ---
    def start(self, audio_path: Path, params: Optional[DiarizationParams] = None) -> str:
        media_id = self.client.upload_audio(audio_path)
        if not media_id:
            return ""
        job_id = self.client.start_diarization(media_id, params=params)
        return job_id or ""

    def get_response(self, job_id: str, *, wait_until_complete: bool = False) -> DiarizationResponse:
        jsr: Optional[JobStatusResponse]

        jsr = self.client.poll_job_until_complete(job_id, wait_until_complete)

        return self.adapter.to_response(jsr)

    def generate(
        self, 
        audio_path: Path, 
        params: Optional[DiarizationParams] = None, 
        *, 
        wait_until_complete: bool = True
        ) -> DiarizationResponse:
        if job_id := self.start(audio_path, params=params):
            return self.get_response(job_id, wait_until_complete=wait_until_complete)
        return self.adapter.failed_start()
adapter = adapter or PyannoteAdapter() instance-attribute
client = client or PyannoteClient() instance-attribute
__init__(client=None, adapter=None)
Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def __init__(self, client: Optional[PyannoteClient] = None, adapter: Optional[PyannoteAdapter] = None):
    self.client = client or PyannoteClient()
    self.adapter = adapter or PyannoteAdapter()
generate(audio_path, params=None, *, wait_until_complete=True)
Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def generate(
    self, 
    audio_path: Path, 
    params: Optional[DiarizationParams] = None, 
    *, 
    wait_until_complete: bool = True
    ) -> DiarizationResponse:
    if job_id := self.start(audio_path, params=params):
        return self.get_response(job_id, wait_until_complete=wait_until_complete)
    return self.adapter.failed_start()
get_response(job_id, *, wait_until_complete=False)
Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def get_response(self, job_id: str, *, wait_until_complete: bool = False) -> DiarizationResponse:
    jsr: Optional[JobStatusResponse]

    jsr = self.client.poll_job_until_complete(job_id, wait_until_complete)

    return self.adapter.to_response(jsr)
start(audio_path, params=None)
Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def start(self, audio_path: Path, params: Optional[DiarizationParams] = None) -> str:
    media_id = self.client.upload_audio(audio_path)
    if not media_id:
        return ""
    job_id = self.client.start_diarization(media_id, params=params)
    return job_id or ""
diarize(audio_file_path, output_path=None, *, params=None, service=None, api_key=None, wait_until_complete=True)

One-shot convenience to generate a result and (optionally) write it.

This returns the DiarizationResponse. Writing is left to callers or diarize_to_file below.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def diarize(
    audio_file_path: Path,
    output_path: Optional[Path] = None,
    *,
    params: Optional[DiarizationParams] = None,
    service: Optional[DiarizationService] = None,
    api_key: Optional[str] = None,
    wait_until_complete: bool = True,
) -> DiarizationResponse:
    """One-shot convenience to generate a result and (optionally) write it.

    This returns the `DiarizationResponse`. Writing is left to callers or
    `diarize_to_file` below.
    """
    processor = DiarizationProcessor(
        audio_file_path,
        output_path=output_path,
        service=service,
        params=params,
        api_key=api_key,
    )
    return processor.generate(wait_until_complete=wait_until_complete)
diarize_to_file(audio_file_path, output_path=None, *, params=None, service=None, api_key=None, wait_until_complete=True)

Convenience helper: generate, then export to JSON if successful; returns the response.

Source code in src/tnh_scholar/audio_processing/diarization/pyannote_diarize.py
def diarize_to_file(
    audio_file_path: Path,
    output_path: Optional[Path] = None,
    *,
    params: Optional[DiarizationParams] = None,
    service: Optional[DiarizationService] = None,
    api_key: Optional[str] = None,
    wait_until_complete: bool = True,
) -> DiarizationResponse:
    """Convenience helper: generate then export to JSON if successful; returns response"""
    processor = DiarizationProcessor(
        audio_file_path,
        output_path=output_path,
        service=service,
        params=params,
        api_key=api_key,
    )
    response = processor.generate(wait_until_complete=wait_until_complete)
    if isinstance(response, DiarizationSucceeded):
        processor.export()
    return response
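The helper's control flow (always return the response, only export on success) can be sketched standalone; `diarize_to_file_sketch` and the status-tagged dicts below are illustrative stand-ins:

```python
import json
import tempfile
from pathlib import Path

def diarize_to_file_sketch(generate, output_path: Path) -> dict:
    """Mirror diarize_to_file's flow: run the one-shot generator, write the
    result only when it succeeded, and return the response either way."""
    response = generate()
    if response["status"] == "succeeded":
        output_path.write_text(json.dumps(response, indent=2))
    return response

out = Path(tempfile.mkdtemp()) / "result.json"
failed = diarize_to_file_sketch(lambda: {"status": "failed"}, out)        # nothing written
succeeded = diarize_to_file_sketch(lambda: {"status": "succeeded"}, out)  # file written
```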
schemas
DiarizationResponse = Annotated[Union[DiarizationSucceeded, DiarizationFailed, DiarizationPending, DiarizationRunning], Field(discriminator='status')] module-attribute
__all__ = ['PollOutcome', 'DiarizationParams', 'StartDiarizationResponse', 'JobStatus', 'JobStatusResponse', 'ErrorCode', 'ErrorInfo', 'DiarizationResult', 'DiarizationSucceeded', 'DiarizationFailed', 'DiarizationPending', 'DiarizationRunning', 'DiarizationResponse'] module-attribute
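`DiarizationResponse` is a union discriminated on `status`, so consumers branch on that field. A stdlib-only sketch of the dispatch (plain dicts stand in for the pydantic models):

```python
def handle(response: dict) -> str:
    """Branch on the 'status' discriminator as a consumer of
    DiarizationResponse would (dicts stand in for the models)."""
    status = response["status"]
    if status == "succeeded":
        return f"{len(response['result']['segments'])} segments"
    if status == "failed":
        return f"error: {response['error']['message']}"
    return f"still {status}"  # "pending" or "running"

ok = handle({"status": "succeeded", "result": {"segments": [1, 2, 3]}})
bad = handle({"status": "failed", "error": {"message": "timeout"}})
wip = handle({"status": "running"})
```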
DiarizationFailed

Bases: _BaseResponse

Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
class DiarizationFailed(_BaseResponse):
    status: Literal["failed"]
    error: ErrorInfo
error instance-attribute
status instance-attribute
DiarizationParams

Bases: BaseModel

Per-request diarization options; maps to pyannote API payload. Use .to_api_dict() to emit API field names.

Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
class DiarizationParams(BaseModel):
    """
    Per-request diarization options; maps to pyannote API payload.
    Use .to_api_dict() to emit API field names.
    """

    model_config = ConfigDict(
        frozen=True,            # make instances immutable
        populate_by_name=True,  # allow using pythonic field names with aliases
        extra="forbid",         # catch accidental fields at construction
    )

    # Pythonic attribute -> API alias on dump
    num_speakers: int | Literal["auto"] | None = Field(
        default=None,
        alias="numSpeakers",
        description="Fixed number of speakers or 'auto' for detection.",
    )
    confidence: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Confidence threshold for segments.",
    )
    webhook: AnyUrl | None = Field(
        default=None,
        description="Webhook URL for job status callbacks.",
    )

    def to_api_dict(self) -> dict[str, Any]:
        """Return payload dict using API field names (camelCase) and excluding Nones."""
        return self.model_dump(by_alias=True, exclude_none=True)
confidence = Field(default=None, ge=0.0, le=1.0, description='Confidence threshold for segments.') class-attribute instance-attribute
model_config = ConfigDict(frozen=True, populate_by_name=True, extra='forbid') class-attribute instance-attribute
num_speakers = Field(default=None, alias='numSpeakers', description="Fixed number of speakers or 'auto' for detection.") class-attribute instance-attribute
webhook = Field(default=None, description='Webhook URL for job status callbacks.') class-attribute instance-attribute
to_api_dict()

Return payload dict using API field names (camelCase) and excluding Nones.

Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
def to_api_dict(self) -> dict[str, Any]:
    """Return payload dict using API field names (camelCase) and excluding Nones."""
    return self.model_dump(by_alias=True, exclude_none=True)
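The effect of `by_alias=True` with `exclude_none=True` can be shown without pydantic; this standalone sketch mimics `to_api_dict`:

```python
def to_api_dict(num_speakers=None, confidence=None, webhook=None) -> dict:
    """Mimic DiarizationParams.to_api_dict(): emit camelCase API field
    names and drop unset (None) values. Standalone sketch only; the real
    class relies on pydantic field aliases."""
    payload = {"numSpeakers": num_speakers, "confidence": confidence, "webhook": webhook}
    return {key: value for key, value in payload.items() if value is not None}

auto = to_api_dict(num_speakers="auto", confidence=0.5)  # {'numSpeakers': 'auto', 'confidence': 0.5}
```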
DiarizationPending

Bases: _BaseResponse

Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
class DiarizationPending(_BaseResponse):
    status: Literal["pending"]
status instance-attribute
DiarizationResult

Bases: BaseModel

Domain-level diarization payload used by the rest of the system. NOTE: segments is intentionally typed as list[Any] so that it can hold your project’s DiarizedSegment instances from models.py without creating an import cycle. You can tighten this typing later to list[DiarizedSegment] and import under TYPE_CHECKING if desired.

Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
class DiarizationResult(BaseModel):
    """
    Domain-level diarization payload used by the rest of the system.
    NOTE: `segments` is intentionally typed as `list[Any]` so that it can
    hold your project’s `DiarizedSegment` instances from `models.py` without
    creating an import cycle. You can tighten this typing later to
    `list[DiarizedSegment]` and import under TYPE_CHECKING if desired.
    """

    model_config = ConfigDict(frozen=True, extra="ignore")

    segments: list[Any]
    num_speakers: int | None = None
    metadata: dict[str, Any] | None = None
metadata = None class-attribute instance-attribute
model_config = ConfigDict(frozen=True, extra='ignore') class-attribute instance-attribute
num_speakers = None class-attribute instance-attribute
segments instance-attribute
DiarizationRunning

Bases: _BaseResponse

Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
class DiarizationRunning(_BaseResponse):
    status: Literal["running"]
status instance-attribute
DiarizationSucceeded

Bases: _BaseResponse

Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
class DiarizationSucceeded(_BaseResponse):
    status: Literal["succeeded"]
    result: DiarizationResult
result instance-attribute
status instance-attribute
ErrorCode

Bases: str, Enum

Client- and adapter-level error taxonomy (not server statuses).

Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
class ErrorCode(str, Enum):
    """Client- and adapter-level error taxonomy (not server statuses)."""

    TIMEOUT = "timeout"        # client-side polling exceeded
    CANCELLED = "cancelled"      # user/client initiated cancellation
    TRANSIENT = "transient"      # retryable infra/network issue
    BAD_REQUEST = "bad_request"  # invalid params before hitting API
    API_ERROR = "api_error"      # remote API responded with error
    PARSE_ERROR = "parse_error"  # unexpected/invalid payload shape
    UNKNOWN = "unknown"
API_ERROR = 'api_error' class-attribute instance-attribute
BAD_REQUEST = 'bad_request' class-attribute instance-attribute
CANCELLED = 'cancelled' class-attribute instance-attribute
PARSE_ERROR = 'parse_error' class-attribute instance-attribute
TIMEOUT = 'timeout' class-attribute instance-attribute
TRANSIENT = 'transient' class-attribute instance-attribute
UNKNOWN = 'unknown' class-attribute instance-attribute
ErrorInfo

Bases: BaseModel

Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
class ErrorInfo(BaseModel):
    model_config = ConfigDict(frozen=True, extra="allow")

    code: ErrorCode
    message: str
    details: dict[str, Any] | None = None
code instance-attribute
details = None class-attribute instance-attribute
message instance-attribute
model_config = ConfigDict(frozen=True, extra='allow') class-attribute instance-attribute
JobHandle dataclass
Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
@dataclass(frozen=True)
class JobHandle:
    job_id: str
    backend: Literal["pyannote"] = "pyannote"
backend = 'pyannote' class-attribute instance-attribute
job_id instance-attribute
__init__(job_id, backend='pyannote')
JobStatus

Bases: str, Enum

Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
class JobStatus(str, Enum):
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    PENDING = "pending"
    RUNNING = "running"
FAILED = 'failed' class-attribute instance-attribute
PENDING = 'pending' class-attribute instance-attribute
RUNNING = 'running' class-attribute instance-attribute
SUCCEEDED = 'succeeded' class-attribute instance-attribute
JobStatusResponse

Bases: BaseModel

Job Status Result (JSR): unified transport payload + client polling context. Combines transport-level fields with client-side polling metadata.

Semantics:

  • outcome describes how polling concluded (terminal success/failure, timeout, network error, etc.).
  • status is the last known server job status (SUCCEEDED, FAILED, RUNNING, PENDING).
  • server_error_msg and payload mirror the remote payload when present.
  • polls and elapsed_s report client polling metrics.

Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
class JobStatusResponse(BaseModel):
    """
    Job Status Result (JSR): unified transport payload + client polling context.
    Combines transport-level fields with client-side polling metadata.

    Semantics:
    - `outcome` describes how polling concluded (terminal success/failure, timeout, network error, etc.).
    - `status` is the last known *server* job status (`SUCCEEDED`, `FAILED`, `RUNNING`, `PENDING`)
    - `server_error_msg` and `payload` mirror the remote payload when present.
    - `polls` and `elapsed_s` report client polling metrics.
    """

    model_config = ConfigDict(frozen=True, extra="ignore")

    # The job id connected to this response
    job_id: str

    # How the client-side polling finished
    outcome: PollOutcome

    # Last known server-side status (may be None if never retrieved)
    status: Optional[JobStatus] = None

    # Transport-mirrored fields (when server responded with them)
    server_error_msg: Optional[str] = None
    payload: Optional[dict[str, Any]] = None

    # Client-side polling metadata
    polls: int = 0
    elapsed_s: float = 0.0
elapsed_s = 0.0 class-attribute instance-attribute
job_id instance-attribute
model_config = ConfigDict(frozen=True, extra='ignore') class-attribute instance-attribute
outcome instance-attribute
payload = None class-attribute instance-attribute
polls = 0 class-attribute instance-attribute
server_error_msg = None class-attribute instance-attribute
status = None class-attribute instance-attribute
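One plausible adapter mapping from (outcome, status) to a response variant, purely illustrative (the actual mapping lives in PyannoteAdapter and may differ):

```python
from typing import Optional

def classify(outcome: str, status: Optional[str]) -> str:
    """Hypothetical collapse of polling outcome plus last server status
    into a DiarizationResponse variant name (illustrative only)."""
    if outcome == "succeeded" and status == "succeeded":
        return "DiarizationSucceeded"
    if status in ("pending", "running") and outcome not in ("failed", "error"):
        return "DiarizationPending" if status == "pending" else "DiarizationRunning"
    return "DiarizationFailed"  # timeouts, network errors, server failures

terminal = classify("succeeded", "succeeded")
timed_out = classify("timeout", "running")
```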
PollOutcome

Bases: str, Enum

Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
class PollOutcome(str, Enum):
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    TIMEOUT = "timeout"
    NETWORK_ERROR = "network_error"
    INTERRUPTED = "interrupted"
    ERROR = "error"
ERROR = 'error' class-attribute instance-attribute
FAILED = 'failed' class-attribute instance-attribute
INTERRUPTED = 'interrupted' class-attribute instance-attribute
NETWORK_ERROR = 'network_error' class-attribute instance-attribute
SUCCEEDED = 'succeeded' class-attribute instance-attribute
TIMEOUT = 'timeout' class-attribute instance-attribute
StartDiarizationResponse

Bases: BaseModel

Minimal typed view of the start-diarization response.

Source code in src/tnh_scholar/audio_processing/diarization/schemas.py
70
71
72
73
74
75
76
77
class StartDiarizationResponse(BaseModel):
    """
    Minimal typed view of the start-diarization response.
    """

    model_config = ConfigDict(frozen=True, extra="ignore")

    job_id: str = Field(alias="jobId")
job_id = Field(alias='jobId') class-attribute instance-attribute
model_config = ConfigDict(frozen=True, extra='ignore') class-attribute instance-attribute
strategies
__all__ = ['LanguageDetector', 'LanguageProbe', 'WhisperLanguageDetector', 'group_speaker_blocks', 'TimeGapChunker'] module-attribute
LanguageDetector

Bases: Protocol

Abstract language detector (e.g., fastText, Whisper-lang).

Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
class LanguageDetector(Protocol):
    """Abstract language detector (e.g., fastText, Whisper-lang)."""

    def detect(self, audio: AudioSegment, format_str: str) -> Optional[str]: ...
detect(audio, format_str)
Source code in src/tnh_scholar/audio_processing/diarization/protocols.py
def detect(self, audio: AudioSegment, format_str: str) -> Optional[str]: ...
LanguageProbe
Source code in src/tnh_scholar/audio_processing/diarization/strategies/language_probe.py
class LanguageProbe:
    def __init__(self, config: DiarizationConfig, detector: LanguageDetector):
        self.probe_time = config.language.probe_time
        self.export_format = config.language.export_format
        self.detector = detector

    def segment_language(
        self,
        aug_segment: AugDiarizedSegment,
    ) -> str:
        """
        Get the segment's ISO-639 language code from an Augmented Diarized Segment that contains audio.

        The probe window is always relative to the segment audio (0=start, duration=end).
        """
        probe_start, probe_end = self._calculate_probe_window(aug_segment)

        if aug_segment.audio is None:
            raise ValueError(f"Segment Audio has not been set: {aug_segment}")

        # All slicing is relative to the segment audio (0 to duration)
        audio_segment = aug_segment.audio[probe_start:probe_end]
        language = self.detector.detect(audio_segment, self.export_format)

        if language is not None:
            return language
        logger.warning(f"No language detected in language probe for segment {aug_segment}.")
        return "unknown"

    def _calculate_probe_window(
        self,
        aug_segment: AugDiarizedSegment,
    ) -> tuple[TimeMs, TimeMs]:
        """
        Calculate start/end times for language probe sampling, 
        always relative to the segment audio (0 to duration).
        """
        duration = aug_segment.duration
        if duration <= self.probe_time:
            return TimeMs(0), duration
        return self._extract_center_window(duration)

    def _extract_center_window(
        self,
        duration: TimeMs,
    ) -> tuple[TimeMs, TimeMs]:
        """
        Extract probe window from center of segment audio (relative time).
        """
        center_time = duration // 2
        half_probe = self.probe_time // 2

        probe_start = TimeMs(max(0, center_time - half_probe))
        probe_end = TimeMs(min(duration, center_time + half_probe))

        return probe_start, probe_end
detector = detector instance-attribute
export_format = config.language.export_format instance-attribute
probe_time = config.language.probe_time instance-attribute
__init__(config, detector)
Source code in src/tnh_scholar/audio_processing/diarization/strategies/language_probe.py
def __init__(self, config: DiarizationConfig, detector: LanguageDetector):
    self.probe_time = config.language.probe_time
    self.export_format = config.language.export_format
    self.detector = detector
segment_language(aug_segment)

Get the segment's ISO-639 language code from an Augmented Diarized Segment that contains audio.

The probe window is always relative to the segment audio (0=start, duration=end).

Source code in src/tnh_scholar/audio_processing/diarization/strategies/language_probe.py
def segment_language(
    self,
    aug_segment: AugDiarizedSegment,
) -> str:
    """
    Get the segment's ISO-639 language code from an Augmented Diarized Segment that contains audio.

    The probe window is always relative to the segment audio (0=start, duration=end).
    """
    probe_start, probe_end = self._calculate_probe_window(aug_segment)

    if aug_segment.audio is None:
        raise ValueError(f"Segment Audio has not been set: {aug_segment}")

    # All slicing is relative to the segment audio (0 to duration)
    audio_segment = aug_segment.audio[probe_start:probe_end]
    language = self.detector.detect(audio_segment, self.export_format)

    if language is not None:
        return language
    logger.warning(f"No language detected in language probe for segment {aug_segment}.")
    return "unknown"
TimeGapChunker

Bases: ChunkingStrategy

Chunker that ignores speaker/language and uses only time-gap logic.

Source code in src/tnh_scholar/audio_processing/diarization/strategies/time_gap.py
class TimeGapChunker(ChunkingStrategy):
    """Chunker that ignores speaker/language and uses only time-gap logic."""

    def __init__(self, config: DiarizationConfig = DiarizationConfig()):
        self.cfg = config

    def extract(self, segments: List[DiarizedSegment]) -> List[DiarizationChunk]:
        """Extract time-based chunks from diarization segments."""
        if not segments:
            return []

        walker = SegmentWalker(segments)
        accumulator = ChunkAccumulator(self.cfg)

        for context in walker.walk():
            if self._should_finalize_chunk(context, accumulator):
                accumulator.finalize_chunk()

            gap_time, gap_before = self._calculate_gap_info(context)
            accumulator.add_segment(context.segment, gap_time, gap_before)

        return accumulator.finalize_and_get_chunks()

    def _should_finalize_chunk(self, context, accumulator: ChunkAccumulator) -> bool:
        """Determine if current chunk should be finalized before adding segment."""
        if not accumulator.current_segments:
            return False

        gap_time, _ = self._calculate_gap_info(context)
        projected_time = accumulator.accumulated_time + context.segment.duration + gap_time

        # Don't split if remaining time would create small final chunk
        if context.remaining_time < self.cfg.chunk.min_duration:
            return False

        return projected_time >= self.cfg.chunk.target_duration

    def _calculate_gap_info(self, context) -> tuple[TimeMs, bool]:
        """Calculate gap time and gap_before flag for current segment."""
        if context.is_first:
            return TimeMs(0), False

        gap_time = context.time_interval_prev or TimeMs(0)
        gap_before = gap_time > self.cfg.chunk.gap_threshold

        # Use configured spacing for large gaps, actual gap time for small gaps
        spacing_time = TimeMs(self.cfg.chunk.gap_spacing_time) if gap_before else gap_time

        return spacing_time, gap_before
cfg = config instance-attribute
__init__(config=DiarizationConfig())
Source code in src/tnh_scholar/audio_processing/diarization/strategies/time_gap.py
def __init__(self, config: DiarizationConfig = DiarizationConfig()):
    self.cfg = config
extract(segments)

Extract time-based chunks from diarization segments.

Source code in src/tnh_scholar/audio_processing/diarization/strategies/time_gap.py
def extract(self, segments: List[DiarizedSegment]) -> List[DiarizationChunk]:
    """Extract time-based chunks from diarization segments."""
    if not segments:
        return []

    walker = SegmentWalker(segments)
    accumulator = ChunkAccumulator(self.cfg)

    for context in walker.walk():
        if self._should_finalize_chunk(context, accumulator):
            accumulator.finalize_chunk()

        gap_time, gap_before = self._calculate_gap_info(context)
        accumulator.add_segment(context.segment, gap_time, gap_before)

    return accumulator.finalize_and_get_chunks()
WhisperLanguageDetector

Language detector using Whisper service.

Source code in src/tnh_scholar/audio_processing/diarization/strategies/language_probe.py
class WhisperLanguageDetector:
    """Language detector using Whisper service."""

    def __init__(self, model: str = "whisper-1", audio_handler: Optional[AudioHandler] = None):
        self.model = model
        self.audio_handler = audio_handler or AudioHandler()

    def detect(self, audio: AudioSegment, format_str: str) -> Optional[str]:
        from tnh_scholar.audio_processing.transcription.whisper_service import WhisperTranscriptionService
        whisper = WhisperTranscriptionService(model=self.model, language=None, response_format="verbose_json")
        try:
            audio_bytes = self.audio_handler.export_audio_bytes(audio, format_str=format_str)
            options = patch_whisper_options(options=None, file_extension=format_str)
            result = whisper.transcribe(audio_bytes, options=options)
            logger.debug(f"full transcription result: {result}")
            return self._extract_language_from_result(result)
        except Exception as e:
            logger.warning(f"Language detection failed: {e}")
            return None

    def _extract_language_from_result(self, result) -> Optional[str]:
        """Extract language code from transcription result."""
        return getattr(result, 'language', None)
audio_handler = audio_handler or AudioHandler() instance-attribute
model = model instance-attribute
__init__(model='whisper-1', audio_handler=None)
Source code in src/tnh_scholar/audio_processing/diarization/strategies/language_probe.py
def __init__(self, model: str = "whisper-1", audio_handler: Optional[AudioHandler] = None):
    self.model = model
    self.audio_handler = audio_handler or AudioHandler()
detect(audio, format_str)
Source code in src/tnh_scholar/audio_processing/diarization/strategies/language_probe.py
def detect(self, audio: AudioSegment, format_str: str) -> Optional[str]:
    from tnh_scholar.audio_processing.transcription.whisper_service import WhisperTranscriptionService
    whisper = WhisperTranscriptionService(model=self.model, language=None, response_format="verbose_json")
    try:
        audio_bytes = self.audio_handler.export_audio_bytes(audio, format_str=format_str)
        options = patch_whisper_options(options=None, file_extension=format_str)
        result = whisper.transcribe(audio_bytes, options=options)
        logger.debug(f"full transcription result: {result}")
        return self._extract_language_from_result(result)
    except Exception as e:
        logger.warning(f"Language detection failed: {e}")
        return None
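The detect flow above degrades gracefully: any exception becomes a warning plus None, and _extract_language_from_result tolerates result objects with no language field. A minimal stdlib-only illustration of that defensive extraction (extract_language, verbose, and plain are illustrative stand-ins, not part of the package):

```python
from types import SimpleNamespace
from typing import Optional

def extract_language(result) -> Optional[str]:
    # Mirrors _extract_language_from_result: a missing attribute yields None, not an error
    return getattr(result, "language", None)

verbose = SimpleNamespace(language="vi")  # hypothetical stand-in for a verbose_json result
plain = SimpleNamespace()                 # result object with no language field
```

With these stand-ins, extract_language(verbose) returns "vi" while extract_language(plain) returns None, so callers never need a try/except for the missing field.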
group_speaker_blocks(segments, config=DiarizationConfig())

Group contiguous or near-contiguous segments by speaker identity.

Segments are grouped into SpeakerBlocks when the speaker remains the same and the gap between consecutive segments is less than the configured threshold.

Parameters:

Name Type Description Default
segments List[DiarizedSegment]

A list of diarization segments (must be sorted by start time).

required
config DiarizationConfig

Configuration containing the allowed gap between segments.

DiarizationConfig()

Returns:

Type Description
List[SpeakerBlock]

A list of SpeakerBlock objects representing grouped speaker runs.

Source code in src/tnh_scholar/audio_processing/diarization/strategies/speaker_blocker.py
def group_speaker_blocks(
    segments: List[DiarizedSegment],
    config: DiarizationConfig = DiarizationConfig()
) -> List[SpeakerBlock]:
    """Group contiguous or near-contiguous segments by speaker identity.

    Segments are grouped into `SpeakerBlock`s when the speaker remains the same
    and the gap between consecutive segments is less than the configured threshold.

    Parameters:
        segments: A list of diarization segments (must be sorted by start time).
        config: Configuration containing the allowed gap between segments.

    Returns:
        A list of SpeakerBlock objects representing grouped speaker runs.
    """
    if not segments:
        return []

    blocks: List[SpeakerBlock] = []
    buffer: List[DiarizedSegment] = [segments[0]]

    gap_threshold = config.speaker.same_speaker_gap_threshold

    for current in segments[1:]:
        previous = buffer[-1]
        same_speaker = current.speaker == previous.speaker
        gap = current.start - previous.end

        if same_speaker and gap <= gap_threshold:
            buffer.append(current)
        else:
            blocks.append(SpeakerBlock(speaker=buffer[0].speaker, segments=buffer))
            buffer = [current]

    if buffer:
        blocks.append(SpeakerBlock(speaker=buffer[0].speaker, segments=buffer))

    return blocks
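The grouping loop above can be exercised without the package types. A stdlib-only sketch, where Seg and group_runs are illustrative stand-ins for DiarizedSegment and group_speaker_blocks, and the 500 ms threshold is an assumed value rather than the package default:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Seg:
    speaker: str
    start: int  # ms
    end: int    # ms

def group_runs(segments: List[Seg], gap_threshold: int = 500) -> List[List[Seg]]:
    """Group same-speaker segments whose inter-segment gap is <= gap_threshold."""
    if not segments:
        return []
    blocks, buffer = [], [segments[0]]
    for cur in segments[1:]:
        prev = buffer[-1]
        if cur.speaker == prev.speaker and (cur.start - prev.end) <= gap_threshold:
            buffer.append(cur)          # continue the current speaker run
        else:
            blocks.append(buffer)       # close the run on speaker change or large gap
            buffer = [cur]
    blocks.append(buffer)               # flush the final run
    return blocks

segs = [Seg("A", 0, 1000), Seg("A", 1200, 2000), Seg("B", 2100, 3000), Seg("A", 4000, 5000)]
blocks = group_runs(segs)  # three runs: [A, A], [B], [A]
```

Note that the trailing A segment starts a new run even though the speaker matches an earlier run: only consecutive segments are merged.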
language_based

LanguageChunker – chunking informed by speaker blocks + language probing.

logger = get_child_logger(__name__) module-attribute
LanguageChunker

Bases: ChunkingStrategy

Strategy:

  1. Group contiguous segments into SpeakerBlock objects.
  2. For each block longer than language_probe_threshold, probe the language at configurable offsets; on mismatch, split at the language change.
  3. Build chunks respecting target_time similar to TimeGapChunker.
Source code in src/tnh_scholar/audio_processing/diarization/strategies/language_based.py
class LanguageChunker(ChunkingStrategy):
    """
    Strategy:

    1. Group contiguous segments into SpeakerBlock objects.
    2. For each block longer than ``language_probe_threshold``, probe the
       language at configurable offsets; on mismatch, split at the language change.
    3. Build chunks respecting ``target_time`` similar to TimeGapChunker.
    """

    def __init__(
        self,
        cfg: ChunkConfig = ChunkConfig(),
        fetcher: AudioFetcher | None = None,
        detector: LanguageDetector | None = None,
        language_probe_threshold: TimeMs = TimeMs(90_000),
    ):
        self.cfg = cfg
        self.fetcher = fetcher
        self.detector = detector
        self.lang_thresh = language_probe_threshold

    def extract(self, segments: List[DiarizedSegment]) -> List[DiarizationChunk]:
        if not segments:
            return []

        blocks = group_speaker_blocks(
            segments, config=self.cfg.speaker_block
        )  # attribute from DiarizationConfig
        # Optionally split blocks on language change
        enriched_segments: List[DiarizedSegment] = []
        for block in blocks:
            if block.duration >= self.lang_thresh and self.fetcher and self.detector:
                enriched_segments.extend(self._split_block_on_language(block))
            else:
                enriched_segments.extend(block.segments)

        # Now fall back to pure time-gap chunking
        from .time_gap import TimeGapChunker

        return TimeGapChunker(self.cfg).extract(enriched_segments)


    def _split_block_on_language(self, block):
        """
        Probe language at 25% and 75% of block; if mismatch, split.
        Very naive – replace with richer algorithm later.
        """
        assert self.fetcher and self.detector  # guaranteed by caller
        first_seg = block.segments[0]
        quarter_point = first_seg.start + (block.duration // 4)
        three_quarter = first_seg.start + (block.duration * 3 // 4)

        probe_segs = [self._segment_at(block, quarter_point),
                      self._segment_at(block, three_quarter)]

        langs = {probe_segment_language(s, self.fetcher, self.detector) for s in probe_segs}

        if len(langs) <= 1:
            return block.segments  # All one language

        # Language split → naively split at midpoint
        midpoint_ms = block.start + (block.duration // 2)
        left, right = [], []
        for seg in block.segments:
            (left if seg.end <= midpoint_ms else right).append(seg)

        return left + right

    def _segment_at(self, block, ms):
        """Return the first segment covering the given ms offset."""
        for seg in block.segments:
            if seg.start <= ms < seg.end:
                return seg
        return block.segments[0]  # fallback
cfg = cfg instance-attribute
detector = detector instance-attribute
fetcher = fetcher instance-attribute
lang_thresh = language_probe_threshold instance-attribute
__init__(cfg=ChunkConfig(), fetcher=None, detector=None, language_probe_threshold=TimeMs(90000))
Source code in src/tnh_scholar/audio_processing/diarization/strategies/language_based.py
def __init__(
    self,
    cfg: ChunkConfig = ChunkConfig(),
    fetcher: AudioFetcher | None = None,
    detector: LanguageDetector | None = None,
    language_probe_threshold: TimeMs = TimeMs(90_000),
):
    self.cfg = cfg
    self.fetcher = fetcher
    self.detector = detector
    self.lang_thresh = language_probe_threshold
extract(segments)
Source code in src/tnh_scholar/audio_processing/diarization/strategies/language_based.py
def extract(self, segments: List[DiarizedSegment]) -> List[DiarizationChunk]:
    if not segments:
        return []

    blocks = group_speaker_blocks(
        segments, config=self.cfg.speaker_block
    )  # attribute from DiarizationConfig
    # Optionally split blocks on language change
    enriched_segments: List[DiarizedSegment] = []
    for block in blocks:
        if block.duration >= self.lang_thresh and self.fetcher and self.detector:
            enriched_segments.extend(self._split_block_on_language(block))
        else:
            enriched_segments.extend(block.segments)

    # Now fall back to pure time-gap chunking
    from .time_gap import TimeGapChunker

    return TimeGapChunker(self.cfg).extract(enriched_segments)
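The 25%/75% probe offsets used by _split_block_on_language reduce to integer arithmetic on the block's start and duration. A short sketch (probe_points is an illustrative helper, not a package function; times are milliseconds):

```python
def probe_points(block_start_ms: int, duration_ms: int) -> tuple:
    """Absolute offsets at 25% and 75% of a block, as in _split_block_on_language."""
    quarter = block_start_ms + duration_ms // 4
    three_quarter = block_start_ms + duration_ms * 3 // 4
    return quarter, three_quarter

probe_points(10_000, 8_000)  # (12000, 16000)
```

Probing at two well-separated points rather than the whole block keeps the number of Whisper calls per block constant regardless of block length.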
language_probe

Lightweight language-detection helpers pluggable into chunkers.

logger = get_child_logger(__name__) module-attribute
LanguageProbe
Source code in src/tnh_scholar/audio_processing/diarization/strategies/language_probe.py
class LanguageProbe:
    def __init__(self, config: DiarizationConfig, detector: LanguageDetector):
        self.probe_time = config.language.probe_time
        self.export_format = config.language.export_format
        self.detector = detector

    def segment_language(
        self,
        aug_segment: AugDiarizedSegment,
    ) -> str:
        """
        Get the ISO-639 language code for an AugDiarizedSegment that contains audio.

        The probe window is always relative to the segment audio (0=start, duration=end).
        """
        probe_start, probe_end = self._calculate_probe_window(aug_segment)

        if aug_segment.audio is None:
            raise ValueError(f"Segment Audio has not been set: {aug_segment}")

        # All slicing is relative to the segment audio (0 to duration)
        audio_segment = aug_segment.audio[probe_start:probe_end]
        language = self.detector.detect(audio_segment, self.export_format)

        if language is not None:
            return language
        logger.warning(f"No language detected in language probe for segment {aug_segment}.")
        return "unknown"

    def _calculate_probe_window(
        self,
        aug_segment: AugDiarizedSegment,
    ) -> tuple[TimeMs, TimeMs]:
        """
        Calculate start/end times for language probe sampling, 
        always relative to the segment audio (0 to duration).
        """
        duration = aug_segment.duration
        if duration <= self.probe_time:
            return TimeMs(0), duration
        return self._extract_center_window(duration)

    def _extract_center_window(
        self,
        duration: TimeMs,
    ) -> tuple[TimeMs, TimeMs]:
        """
        Extract probe window from center of segment audio (relative time).
        """
        center_time = duration // 2
        half_probe = self.probe_time // 2

        probe_start = TimeMs(max(0, center_time - half_probe))
        probe_end = TimeMs(min(duration, center_time + half_probe))

        return probe_start, probe_end
detector = detector instance-attribute
export_format = config.language.export_format instance-attribute
probe_time = config.language.probe_time instance-attribute
__init__(config, detector)
Source code in src/tnh_scholar/audio_processing/diarization/strategies/language_probe.py
def __init__(self, config: DiarizationConfig, detector: LanguageDetector):
    self.probe_time = config.language.probe_time
    self.export_format = config.language.export_format
    self.detector = detector
segment_language(aug_segment)

Get the ISO-639 language code for an AugDiarizedSegment that contains audio.

The probe window is always relative to the segment audio (0=start, duration=end).

Source code in src/tnh_scholar/audio_processing/diarization/strategies/language_probe.py
def segment_language(
    self,
    aug_segment: AugDiarizedSegment,
) -> str:
    """
    Get the ISO-639 language code for an AugDiarizedSegment that contains audio.

    The probe window is always relative to the segment audio (0=start, duration=end).
    """
    probe_start, probe_end = self._calculate_probe_window(aug_segment)

    if aug_segment.audio is None:
        raise ValueError(f"Segment Audio has not been set: {aug_segment}")

    # All slicing is relative to the segment audio (0 to duration)
    audio_segment = aug_segment.audio[probe_start:probe_end]
    language = self.detector.detect(audio_segment, self.export_format)

    if language is not None:
        return language
    logger.warning(f"No language detected in language probe for segment {aug_segment}.")
    return "unknown"
WhisperLanguageDetector

Language detector using Whisper service.

Source code in src/tnh_scholar/audio_processing/diarization/strategies/language_probe.py
class WhisperLanguageDetector:
    """Language detector using Whisper service."""

    def __init__(self, model: str = "whisper-1", audio_handler: Optional[AudioHandler] = None):
        self.model = model
        self.audio_handler = audio_handler or AudioHandler()

    def detect(self, audio: AudioSegment, format_str: str) -> Optional[str]:
        from tnh_scholar.audio_processing.transcription.whisper_service import WhisperTranscriptionService
        whisper = WhisperTranscriptionService(model=self.model, language=None, response_format="verbose_json")
        try:
            audio_bytes = self.audio_handler.export_audio_bytes(audio, format_str=format_str)
            options = patch_whisper_options(options=None, file_extension=format_str)
            result = whisper.transcribe(audio_bytes, options=options)
            logger.debug(f"full transcription result: {result}")
            return self._extract_language_from_result(result)
        except Exception as e:
            logger.warning(f"Language detection failed: {e}")
            return None

    def _extract_language_from_result(self, result) -> Optional[str]:
        """Extract language code from transcription result."""
        return getattr(result, 'language', None)
audio_handler = audio_handler or AudioHandler() instance-attribute
model = model instance-attribute
__init__(model='whisper-1', audio_handler=None)
Source code in src/tnh_scholar/audio_processing/diarization/strategies/language_probe.py
def __init__(self, model: str = "whisper-1", audio_handler: Optional[AudioHandler] = None):
    self.model = model
    self.audio_handler = audio_handler or AudioHandler()
detect(audio, format_str)
Source code in src/tnh_scholar/audio_processing/diarization/strategies/language_probe.py
def detect(self, audio: AudioSegment, format_str: str) -> Optional[str]:
    from tnh_scholar.audio_processing.transcription.whisper_service import WhisperTranscriptionService
    whisper = WhisperTranscriptionService(model=self.model, language=None, response_format="verbose_json")
    try:
        audio_bytes = self.audio_handler.export_audio_bytes(audio, format_str=format_str)
        options = patch_whisper_options(options=None, file_extension=format_str)
        result = whisper.transcribe(audio_bytes, options=options)
        logger.debug(f"full transcription result: {result}")
        return self._extract_language_from_result(result)
    except Exception as e:
        logger.warning(f"Language detection failed: {e}")
        return None
speaker_blocker
group_speaker_blocks(segments, config=DiarizationConfig())

Group contiguous or near-contiguous segments by speaker identity.

Segments are grouped into SpeakerBlocks when the speaker remains the same and the gap between consecutive segments is less than the configured threshold.

Parameters:

Name Type Description Default
segments List[DiarizedSegment]

A list of diarization segments (must be sorted by start time).

required
config DiarizationConfig

Configuration containing the allowed gap between segments.

DiarizationConfig()

Returns:

Type Description
List[SpeakerBlock]

A list of SpeakerBlock objects representing grouped speaker runs.

Source code in src/tnh_scholar/audio_processing/diarization/strategies/speaker_blocker.py
def group_speaker_blocks(
    segments: List[DiarizedSegment],
    config: DiarizationConfig = DiarizationConfig()
) -> List[SpeakerBlock]:
    """Group contiguous or near-contiguous segments by speaker identity.

    Segments are grouped into `SpeakerBlock`s when the speaker remains the same
    and the gap between consecutive segments is less than the configured threshold.

    Parameters:
        segments: A list of diarization segments (must be sorted by start time).
        config: Configuration containing the allowed gap between segments.

    Returns:
        A list of SpeakerBlock objects representing grouped speaker runs.
    """
    if not segments:
        return []

    blocks: List[SpeakerBlock] = []
    buffer: List[DiarizedSegment] = [segments[0]]

    gap_threshold = config.speaker.same_speaker_gap_threshold

    for current in segments[1:]:
        previous = buffer[-1]
        same_speaker = current.speaker == previous.speaker
        gap = current.start - previous.end

        if same_speaker and gap <= gap_threshold:
            buffer.append(current)
        else:
            blocks.append(SpeakerBlock(speaker=buffer[0].speaker, segments=buffer))
            buffer = [current]

    if buffer:
        blocks.append(SpeakerBlock(speaker=buffer[0].speaker, segments=buffer))

    return blocks
time_gap

TimeGapChunker – baseline strategy: split purely on accumulated time.

logger = get_child_logger(__name__) module-attribute
TimeGapChunker

Bases: ChunkingStrategy

Chunker that ignores speaker/language and uses only time-gap logic.

Source code in src/tnh_scholar/audio_processing/diarization/strategies/time_gap.py
class TimeGapChunker(ChunkingStrategy):
    """Chunker that ignores speaker/language and uses only time-gap logic."""

    def __init__(self, config: DiarizationConfig = DiarizationConfig()):
        self.cfg = config

    def extract(self, segments: List[DiarizedSegment]) -> List[DiarizationChunk]:
        """Extract time-based chunks from diarization segments."""
        if not segments:
            return []

        walker = SegmentWalker(segments)
        accumulator = ChunkAccumulator(self.cfg)

        for context in walker.walk():
            if self._should_finalize_chunk(context, accumulator):
                accumulator.finalize_chunk()

            gap_time, gap_before = self._calculate_gap_info(context)
            accumulator.add_segment(context.segment, gap_time, gap_before)

        return accumulator.finalize_and_get_chunks()

    def _should_finalize_chunk(self, context, accumulator: ChunkAccumulator) -> bool:
        """Determine if current chunk should be finalized before adding segment."""
        if not accumulator.current_segments:
            return False

        gap_time, _ = self._calculate_gap_info(context)
        projected_time = accumulator.accumulated_time + context.segment.duration + gap_time

        # Don't split if remaining time would create small final chunk
        if context.remaining_time < self.cfg.chunk.min_duration:
            return False

        return projected_time >= self.cfg.chunk.target_duration

    def _calculate_gap_info(self, context) -> tuple[TimeMs, bool]:
        """Calculate gap time and gap_before flag for current segment."""
        if context.is_first:
            return TimeMs(0), False

        gap_time = context.time_interval_prev or TimeMs(0)
        gap_before = gap_time > self.cfg.chunk.gap_threshold

        # Use configured spacing for large gaps, actual gap time for small gaps
        spacing_time = TimeMs(self.cfg.chunk.gap_spacing_time) if gap_before else gap_time

        return spacing_time, gap_before
cfg = config instance-attribute
__init__(config=DiarizationConfig())
Source code in src/tnh_scholar/audio_processing/diarization/strategies/time_gap.py
def __init__(self, config: DiarizationConfig = DiarizationConfig()):
    self.cfg = config
extract(segments)

Extract time-based chunks from diarization segments.

Source code in src/tnh_scholar/audio_processing/diarization/strategies/time_gap.py
def extract(self, segments: List[DiarizedSegment]) -> List[DiarizationChunk]:
    """Extract time-based chunks from diarization segments."""
    if not segments:
        return []

    walker = SegmentWalker(segments)
    accumulator = ChunkAccumulator(self.cfg)

    for context in walker.walk():
        if self._should_finalize_chunk(context, accumulator):
            accumulator.finalize_chunk()

        gap_time, gap_before = self._calculate_gap_info(context)
        accumulator.add_segment(context.segment, gap_time, gap_before)

    return accumulator.finalize_and_get_chunks()
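TimeGapChunker's split decision in _should_finalize_chunk boils down to two checks: never strand a too-short final chunk, and split once the projected duration reaches the target. A stdlib-only sketch of that rule (should_finalize is an illustrative stand-in; it omits the empty-accumulator guard the real method applies first):

```python
def should_finalize(accumulated_ms: int, segment_ms: int, gap_ms: int,
                    remaining_ms: int, target_ms: int, min_final_ms: int) -> bool:
    """Decide whether to close the current chunk before adding the next segment."""
    if remaining_ms < min_final_ms:
        return False  # splitting here would leave a tiny trailing chunk
    return accumulated_ms + segment_ms + gap_ms >= target_ms
```

For example, with a 300 s target and a 60 s minimum final chunk, a chunk holding 290 s splits before a 15 s segment when 120 s of audio remains, but not when only 30 s remains.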
timeline_mapper

Timeline mapping utilities for transforming timestamps from chunk-relative coordinates to original audio coordinates.

This module enables mapping transcript segments back to their original positions in the source audio after processing chunked audio.

logger = get_child_logger(__name__) module-attribute
TimelineMapper

Maps timestamps from chunk-relative coordinates to original audio coordinates.

Source code in src/tnh_scholar/audio_processing/diarization/timeline_mapper.py
class TimelineMapper:
    """Maps timestamps from chunk-relative coordinates to original audio coordinates."""

    def __init__(self, config: Optional[TimelineMapperConfig] = None):
        """Initialize with optional configuration."""
        self.config = config or TimelineMapperConfig()

    def remap(self, timed_text: TimedText, chunk: DiarizationChunk) -> TimedText:
        """
        Remap all timestamps in a TimedText object from chunk-relative to original audio coordinates.

        Args:
            timed_text: TimedText with chunk-relative timestamps
            chunk: DiarizationChunk containing mapping information

        Returns:
            New TimedText object with remapped timestamps
        """

        self._validate_diarize_segments(chunk)
        mapper = self._TimeUnitMapper(chunk.segments, self.config)
        self._validate_timed_text(timed_text)    

        # map_timed_text remaps both segment- and word-level units in one pass;
        # calling it once avoids applying the offset transformation twice.
        if timed_text.segments or timed_text.words:
            timed_text = mapper.map_timed_text(timed_text)

        return timed_text

    def _validate_diarize_segments(self, chunk: DiarizationChunk):
        if not (segments := chunk.segments):
            logger.error("Empty segments.")
            raise ValueError("Cannot remap with empty chunk segments.")

        # Validate segments
        for segment in segments:
            if segment.audio_map_start is None:
                raise ValueError(f"Remap not possible. Segment {segment} is missing audio_map_time.")
            segment.normalize() 

    def _validate_timed_text(self, timed_text: TimedText):
        timed_text.sort_by_start()
        for timed_unit in timed_text.iter():
            timed_unit.normalize()

    class _TimeUnitMapper:
        """Internal helper class for time-unit mapping."""

        def __init__(self, map_segments: List[DiarizedSegment], config: TimelineMapperConfig):
            self.map_segments = map_segments
            self.config = config            

        def map_timed_text(self, tt: TimedText) -> TimedText:
            """Map timestamps in all TimedTextUnit collections contained in the TimedText object."""
            new_tt = tt.model_copy(deep=True)

            sources: List[Tuple[str, List[TimedTextUnit]]] = []
            if tt.is_segment_granularity():
                sources.append(("segments", tt.segments))
            if tt.is_word_granularity():
                sources.append(("words", tt.words))

            for attr, units in sources:
                mapped_units = [self._map_text_unit(u) for u in units]
                setattr(new_tt, attr, mapped_units)

            return new_tt

        def _map_text_unit(self, unit: TimedTextUnit) -> TimedTextUnit:
            """Map a single TimedTextUnit's timestamps."""
            # Find the best matching segment
            best_segment = self._find_best_segment(unit)

            # Debug logging for mapping decision
            if self.config.debug_logging:
                self._log_mapping_choice(unit, best_segment)

            # Apply mapping transformation and return new unit
            return self._apply_mapping(
                unit,
                best_segment
            )


        def _log_mapping_choice(self, unit, segment):
            logger.info(
                    f"Mapping unit (start: {unit.start_ms}, end: {unit.end_ms}) "
                    f"to segment (start: {segment.start}, end: {segment.end}, "
                    f"mapped_start: {segment.mapped_start}, mapped_end: {segment.mapped_end})"
                )

        def _find_best_segment(self, unit: TimedTextUnit) -> DiarizedSegment:
            """
            Find the best segment to use for mapping a TimedTextUnit.

            First tries to find segments with direct overlap.
            If none, finds proximal segments and chooses the closest.
            """
            if overlapping := self._find_overlapping_segments(unit):
                return self._choose_best_overlap(unit, overlapping)

            # If no overlaps, find proximal segments
            before, after = self._find_proximal_segments(unit)

            if before is None and after is None:
                raise ValueError("A before or after segment was not found.")

            if before is None:
                return after # type: ignore
            if after is None:
                return before

            # Choose closest proximal segment
            return self._choose_closest_proximal(unit, before, after)

        def _find_overlapping_segments(self, unit: TimedTextUnit) -> List[DiarizedSegment]:
            """Find all segments that overlap with the given unit."""
            return [
                segment for segment in self.map_segments
                if (segment.mapped_start <= unit.end_ms and
                    segment.mapped_end >= unit.start_ms)
            ]

        def _choose_best_overlap(
            self, 
            unit: TimedTextUnit, 
            candidates: List[DiarizedSegment]
        ) -> DiarizedSegment:
            """Choose the segment with the largest overlap with the unit."""
            best_segment = candidates[0]
            best_overlap = self._calculate_overlap(unit, best_segment)

            for segment in candidates[1:]:
                overlap = self._calculate_overlap(unit, segment)
                if overlap > best_overlap:
                    best_overlap = overlap
                    best_segment = segment

            return best_segment

        def _calculate_overlap(
            self, 
            unit: TimedTextUnit, 
            segment: DiarizedSegment
        ) -> int:
            """Calculate the amount of overlap between a unit and a segment in milliseconds."""     
            overlap_start = max(unit.start_ms, segment.mapped_start)
            overlap_end = min(unit.end_ms, segment.mapped_end)

            return max(0, overlap_end - overlap_start)

        def _find_proximal_segments(
            self, 
            unit: TimedTextUnit
        ) -> Tuple[Optional[DiarizedSegment], Optional[DiarizedSegment]]:
            """Find the nearest segments before and after the unit."""
            before = None
            before_end = float('-inf')
            after = None
            after_start = float('inf')

            for segment in self.map_segments:
                # Check if segment ends before unit starts
                if segment.mapped_end <= unit.start_ms and segment.mapped_end > before_end:
                    before = segment
                    before_end = segment.mapped_end

                # Check if segment starts after unit ends
                if segment.mapped_start >= unit.end_ms and segment.mapped_start < after_start:
                    after = segment
                    after_start = segment.mapped_start

            if not (before or after):
                raise ValueError("Before or after segments not found.")

            return before, after

        def _choose_closest_proximal(
            self,
            unit: TimedTextUnit,
            before: DiarizedSegment,
            after: DiarizedSegment
        ) -> DiarizedSegment:
            """
            Choose the closest proximal segment based on gap distance.
            Requires both before and after segments (cannot be None)
            """
            before_gap = unit.start_ms - before.mapped_end
            after_gap = after.mapped_start - unit.end_ms

            # Choose segment with smaller gap
            return before if before_gap <= after_gap else after

        def _apply_mapping(
            self, 
            unit: TimedTextUnit,
            segment: DiarizedSegment
        ) -> TimedTextUnit:
            """
            Apply the timeline mapping transformation.

            Maps unit timestamps from chunk-relative to original timeline.
            Returns a new TimedTextUnit with remapped timestamps.
            """
            # Calculate offset from segment's local start
            offset = unit.start_ms - segment.mapped_start

            # Apply offset to original timeline
            new_unit_start = segment.start + offset

            # Preserve duration
            duration = unit.duration_ms
            new_unit_end = new_unit_start + duration

            unit = unit.model_copy(
                update={"start_ms": new_unit_start, "end_ms": new_unit_end}
            )

            if self.config.map_speakers:
                unit.set_speaker(segment.speaker)

            return unit
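The arithmetic in `_calculate_overlap` and `_apply_mapping` is plain interval math; it can be sketched standalone, with (start_ms, end_ms) tuples standing in for the real TimedTextUnit and DiarizedSegment types (the names below are illustrative, not the package's API):

```python
# Illustrative interval math mirroring _calculate_overlap and _apply_mapping.
# Plain (start_ms, end_ms) tuples stand in for TimedTextUnit/DiarizedSegment.

def overlap_ms(unit, segment):
    """Overlap in milliseconds between a unit and a segment's mapped interval."""
    start = max(unit[0], segment[0])
    end = min(unit[1], segment[1])
    return max(0, end - start)

def apply_mapping(unit, segment_mapped_start, segment_original_start):
    """Shift a chunk-relative unit onto the original timeline, preserving duration."""
    offset = unit[0] - segment_mapped_start
    new_start = segment_original_start + offset
    return (new_start, new_start + (unit[1] - unit[0]))
```

A unit at 1200-1500 ms inside a chunk whose segment was mapped to start at 1000 ms, but originally began at 61000 ms, lands at 61200-61500 ms on the master timeline.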
config = config or TimelineMapperConfig() instance-attribute
__init__(config=None)

Initialize with optional configuration.

Source code in src/tnh_scholar/audio_processing/diarization/timeline_mapper.py
def __init__(self, config: Optional[TimelineMapperConfig] = None):
    """Initialize with optional configuration."""
    self.config = config or TimelineMapperConfig()
remap(timed_text, chunk)

Remap all timestamps in a TimedText object from chunk-relative to original audio coordinates.

Parameters:

  • timed_text (TimedText): TimedText with chunk-relative timestamps. Required.
  • chunk (DiarizationChunk): DiarizationChunk containing mapping information. Required.

Returns:

  • TimedText: New TimedText object with remapped timestamps.

Source code in src/tnh_scholar/audio_processing/diarization/timeline_mapper.py
def remap(self, timed_text: TimedText, chunk: DiarizationChunk) -> TimedText:
    """
    Remap all timestamps in a TimedText object from chunk-relative to original audio coordinates.

    Args:
        timed_text: TimedText with chunk-relative timestamps
        chunk: DiarizationChunk containing mapping information

    Returns:
        New TimedText object with remapped timestamps
    """

    self._validate_diarize_segments(chunk)
    mapper = self._TimeUnitMapper(chunk.segments, self.config)
    self._validate_timed_text(timed_text)    

    if timed_text.segments:    
        timed_text = mapper.map_timed_text(timed_text)

    if timed_text.words:
        timed_text = mapper.map_timed_text(timed_text)

    return timed_text
TimelineMapperConfig

Bases: BaseModel

Configuration options for timeline mapping.

Source code in src/tnh_scholar/audio_processing/diarization/timeline_mapper.py
class TimelineMapperConfig(BaseModel):
    """Configuration options for timeline mapping."""

    debug_logging: bool = Field(
        default=False,
        description="Enable detailed logging of mapping decisions"
    )
    map_speakers: bool = Field(
        default=True,
        description="Assign speaker to mapped timings using diarization segment speaker."
    )
debug_logging = Field(default=False, description='Enable detailed logging of mapping decisions') class-attribute instance-attribute
map_speakers = Field(default=True, description='Assign speaker to mapped timings using diarization segment speaker.') class-attribute instance-attribute
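The config class above is a standard pydantic model; the same pattern in a minimal, self-contained sketch (the class name `MapperConfig` is illustrative):

```python
from pydantic import BaseModel, Field

# Minimal sketch of the Field(default=..., description=...) pattern used above.
class MapperConfig(BaseModel):
    debug_logging: bool = Field(default=False, description="Enable detailed logging")
    map_speakers: bool = Field(default=True, description="Assign speakers to mapped timings")

cfg = MapperConfig()                       # all defaults
custom = MapperConfig(map_speakers=False)  # override one field
```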
types
PyannoteEntry

Bases: TypedDict

Source code in src/tnh_scholar/audio_processing/diarization/types.py
class PyannoteEntry(TypedDict):
    speaker: str
    start: float  # seconds
    end: float    # seconds
end instance-attribute
speaker instance-attribute
start instance-attribute
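Since PyannoteEntry is a TypedDict, it is an ordinary dict at runtime; the annotation only informs static type checkers. For example:

```python
from typing import TypedDict

class PyannoteEntry(TypedDict):
    speaker: str
    start: float  # seconds
    end: float    # seconds

# At runtime this is just a dict; keys and types are checked statically only.
entry: PyannoteEntry = {"speaker": "SPEAKER_00", "start": 12.4, "end": 15.9}
duration = entry["end"] - entry["start"]
```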
viewer
close_segment_viewer(pid)

Terminate the Streamlit viewer process by PID.

Source code in src/tnh_scholar/audio_processing/diarization/viewer.py
def close_segment_viewer(pid: int):
    """Terminate the Streamlit viewer process by PID."""
    try:
        os.kill(pid, signal.SIGTERM)
        print(f"Closed Streamlit viewer (PID {pid})")
    except Exception as e:
        print(f"Failed to close Streamlit viewer (PID {pid}): {e}")
launch_segment_viewer(segments, master_audio_file)

Export segment data to a temporary JSON file and launch the Streamlit viewer.

Parameters:

  • segments (List[SpeakerBlock]): Speaker blocks with diarization info (start, end, speaker). Required.
  • master_audio_file (Path): Path to the master audio file. Required.

Source code in src/tnh_scholar/audio_processing/diarization/viewer.py
def launch_segment_viewer(segments: List[SpeakerBlock], master_audio_file: Path):
    """
    Export segment data to a temporary JSON file and launch Streamlit viewer.
    Args:
        segments: List of SpeakerBlock objects with diarization info (start, end, speaker).
        master_audio_file: Path to the master audio file.
    """
    # Attach master audio file path to metadata
    meta = {"master_audio": str(master_audio_file)}
    serial_segments = [segment.to_dict() for segment in segments]
    payload = {"segments": serial_segments, "meta": meta}
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
        json.dump(payload, f)
        temp_path = f.name
    cmd = [sys.executable, "-m", "streamlit", "run", str(Path(__file__).resolve()), "--", temp_path]
    print(f"Launching Streamlit viewer with data: {temp_path}")
    proc = subprocess.Popen(cmd)
    return proc.pid
load_segments_from_file(path)
Source code in src/tnh_scholar/audio_processing/diarization/viewer.py
def load_segments_from_file(path):
    with open(path, "r") as f:
        return json.load(f)
main()
Source code in src/tnh_scholar/audio_processing/diarization/viewer.py
def main():
    # If a data file is passed as argument, load it
    segments = None
    meta = None
    error_msg = None
    if len(sys.argv) > 1 and os.path.exists(sys.argv[-1]):
        try:
            payload = load_segments_from_file(sys.argv[-1])
            segments = payload.get("segments")
            meta = payload.get("meta")
        except Exception as e:
            error_msg = f"Failed to load segment data: {e}"
    else:
        st.error("No segment data file provided. This viewer requires explicit segment and audio file input.")
        st.stop()

    if error_msg:
        st.error(error_msg)
        st.stop()

    if not segments or not meta or not meta.get("master_audio"):
        st.error("Segments and master audio file must be provided.")
        st.stop()

    master_audio_path = meta["master_audio"]

    # --- Deserialize SpeakerBlocks from dicts ---
    blocks = [SpeakerBlock.from_dict(seg) for seg in segments]

    # Enable wide mode for Streamlit app
    st.set_page_config(layout="wide")
    st.write("## Segment Timeline Plot (seconds)")
    if not blocks:
        st.error("No segment blocks found.")
        st.stop()


    # --- Timeline Plot: group by speaker, color by speaker, number blocks ---
    try:
        speakers = list({block.speaker for block in blocks})
        color_map = {
            spk: pc.qualitative.Plotly[i % len(pc.qualitative.Plotly)] for i, spk in enumerate(speakers)
        }

        fig = go.Figure()
        speaker_blocks = defaultdict(list)
        for idx, block in enumerate(blocks):
            speaker_blocks[block.speaker].append((idx, block))

        bar_thickness = 0.6
        for speaker, items in speaker_blocks.items():
            y_val = speaker
            for idx, block in items:
                start_sec = block.start.to_seconds()
                duration_sec = block.duration.to_seconds()
                fig.add_trace(go.Bar(
                    x=[duration_sec],
                    y=[y_val],
                    base=[start_sec],
                    orientation='h',
                    marker_color=color_map[speaker],
                    name=f"{idx+1}: {speaker}",
                    hovertext=f"{idx+1}: {speaker} ({start_sec:.2f}s-{start_sec+duration_sec:.2f}s)",
                    width=bar_thickness
                ))
        fig.update_layout(
            title="All Segments (seconds)",
            xaxis_title="Time (seconds)",
            yaxis_title="Speaker",
            showlegend=False,
            bargap=0.2,
            barmode="overlay"
        )
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        st.error(f"Error generating timeline plot: {e}")


    # --- Segment selection via entry box ---
    st.write("## Enter Segment Number to Play")
    max_segment = len(blocks)
    segment_num = st.number_input(
        "Segment number (1-based)",
        min_value=1,
        max_value=max_segment,
        value=1,
        step=1,
        help=f"Enter a segment number between 1 and {max_segment}"
    )
    selected_idx = segment_num - 1

    block = blocks[selected_idx]
    start_ms = block.start.to_ms()
    end_ms = block.end.to_ms()
    st.write(f"Selected Segment: {segment_num} | Speaker: {block.speaker}")
    st.write(
        f"Start: {block.start.to_seconds():.2f}s, "
        f"End: {block.end.to_seconds():.2f}s, "
        f"Duration: {block.duration.to_seconds():.2f}s"
    )

    # --- Play audio for selected segment ---
    try:
        audio = AudioSegment.from_file(master_audio_path)
        segment_audio = audio[start_ms:end_ms]
        buf = io.BytesIO()
        segment_audio.export(buf, format="wav")
        st.audio(buf.getvalue(), format="audio/wav")
    except Exception as e:
        st.error(f"Error extracting or playing audio segment: {e}")

timed_object

__all__ = ['Granularity', 'TimedText', 'TimedTextUnit'] module-attribute
Granularity

Bases: str, Enum

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
class Granularity(str, Enum):
    SEGMENT = "segment"
    WORD = "word"
SEGMENT = 'segment' class-attribute instance-attribute
WORD = 'word' class-attribute instance-attribute
TimedText

Bases: BaseModel

Represents a collection of timed text units of a single granularity.

Only one of segments or words is populated, determined by granularity. All units must match the declared granularity.

Notes
  • Start times must be non-decreasing (overlaps allowed for multiple speakers).
  • Negative start_ms or end_ms values are not allowed.
  • Durations must be strictly positive (>0 ms).
  • Mixed granularity is strictly prohibited.
Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
class TimedText(BaseModel):
    """
    Represents a collection of timed text units of a single granularity.

    Only one of `segments` or `words` is populated, determined by `granularity`.
    All units must match the declared granularity.

    Notes:
        - Start times must be non-decreasing (overlaps allowed for multiple speakers).
        - Negative start_ms or end_ms values are not allowed.
        - Durations must be strictly positive (>0 ms).
        - Mixed granularity is strictly prohibited.
    """

    granularity: Granularity = Field(..., description="Granularity type for all units.")
    segments: List[TimedTextUnit] = Field(default_factory=list, description="Phrase-level timed units")
    words: List[TimedTextUnit] = Field(default_factory=list, description="Word-level timed units")

    def __init__(
        self,
        *,
        granularity: Optional[Granularity] = None,
        segments: Optional[List[TimedTextUnit]] = None,
        words: Optional[List[TimedTextUnit]] = None,
        units: Optional[List[TimedTextUnit]] = None,
        **kwargs
    ):
        """
        Custom initializer for TimedText.
        If `units` is provided, granularity is inferred from the first unit unless explicitly set.
        If only `segments` or `words` is provided, granularity is set accordingly.
        If all are empty, granularity must be provided.
        """
        segments = segments or []
        words = words or []
        if units is not None:
            if units:
                inferred_granularity = units[0].granularity
                granularity = granularity or inferred_granularity
                if granularity == Granularity.SEGMENT:
                    segments = units
                    words = []
                elif granularity == Granularity.WORD:
                    words = units
                    segments = []
                else:
                    raise ValueError("Invalid granularity inferred from units.")
            else:
                if granularity is None:
                    raise ValueError("Must provide granularity for empty TimedText.")
        elif segments:
            granularity = granularity or Granularity.SEGMENT
            words = []
        elif words:
            granularity = granularity or Granularity.WORD
            segments = []
        elif granularity is None:
            raise ValueError("Must provide granularity for empty TimedText.")

        super().__init__(granularity=granularity, segments=segments, words=words, **kwargs)

    @model_validator(mode="after")
    def _validate_exclusive_granularity(self):
        """
        Validate that TimedText contains only units matching its granularity.
        Allows empty TimedText objects for prototyping and construction.
        Modular logic for segments and words.
        """
        granularity = self.granularity
        segments = self.segments
        words = self.words

        if granularity == Granularity.SEGMENT:
            if words:
                raise ValueError("TimedText with SEGMENT granularity must not have word units.")
            for unit in segments:
                if unit.granularity != Granularity.SEGMENT:
                    raise ValueError("All segment units must have granularity SEGMENT.")
        elif granularity == Granularity.WORD:
            if segments:
                raise ValueError("TimedText with WORD granularity must not have segment units.")
            for unit in words:
                if unit.granularity != Granularity.WORD:
                    raise ValueError("All word units must have granularity WORD.")
        else:
            raise ValueError("Invalid granularity type.")
        return self

    def model_post_init(self, __context) -> None:
        """
        After initialization, sort units by start time and normalize durations.
        """
        self.sort_by_start()
        for unit in self.units:
            unit.normalize()

    @property
    def units(self) -> List[TimedTextUnit]:
        """Return the list of units matching the granularity."""
        return self.segments if self.granularity == Granularity.SEGMENT else self.words

    def is_segment_granularity(self) -> bool:
        """Return True if granularity is SEGMENT."""
        return self.granularity == Granularity.SEGMENT

    def is_word_granularity(self) -> bool:
        """Return True if granularity is WORD."""
        return self.granularity == Granularity.WORD

    @property
    def start_ms(self) -> int:
        """Get the start time of the earliest unit."""
        return min(unit.start_ms for unit in self.units) if self.units else 0

    @property
    def end_ms(self) -> int:
        """Get the end time of the latest unit."""
        return max(unit.end_ms for unit in self.units) if self.units else 0

    @property
    def duration(self) -> int:
        """Get the total duration in milliseconds."""
        return self.end_ms - self.start_ms

    def __len__(self) -> int:
        """Return the number of units."""
        return len(self.units)

    def append(self, unit: TimedTextUnit):
        """Add a unit to the end."""
        if unit.granularity != self.granularity:
            raise ValueError(f"Cannot append unit with granularity {unit.granularity} "
                             f"to TimedText of granularity {self.granularity}.")
        self.units.append(unit)

    def extend(self, units: List[TimedTextUnit]):
        """Add multiple units to the end."""
        for unit in units:
            self.append(unit)

    def clear(self):
        """Remove all units."""
        self.units.clear()

    def set_speaker(self, index: int, speaker: str) -> None:
        """Set speaker for a specific unit by index."""
        if not (0 <= index < len(self.units)):
            raise IndexError(f"Index {index} out of range for units.")
        self.units[index].set_speaker(speaker)

    def set_all_speakers(self, speaker: str) -> None:
        """Set the same speaker for all units."""
        for unit in self.units:
            unit.set_speaker(speaker)

    def shift(self, offset_ms: int) -> None:
        """Shift all units by a given offset in milliseconds."""
        for i, unit in enumerate(self.units):
            self.units[i] = unit.shift_time(offset_ms)

    def sort_by_start(self) -> None:
        """Sort units by start time."""
        self.units.sort(key=lambda unit: unit.start_ms)


    @classmethod
    def _new_with_units(
        cls, units: List[TimedTextUnit], granularity: Optional[Granularity] = None
    ) -> "TimedText":
        """
        Helper to create a new TimedText object with the given granularity and units.
        If granularity is not provided, it is inferred from the first unit.
        """
        if units:
            inferred_granularity = units[0].granularity
            granularity = granularity or inferred_granularity
            if granularity == Granularity.SEGMENT:
                return cls(granularity=granularity, segments=units, words=[])
            elif granularity == Granularity.WORD:
                return cls(granularity=granularity, segments=[], words=units)
            else:
                raise ValueError("Invalid granularity inferred from units.")
        else:
            if granularity is None:
                raise ValueError("Must provide granularity for empty TimedText.")
            if granularity in [Granularity.SEGMENT, Granularity.WORD]:
                return cls(granularity=granularity, segments=[], words=[])
            else:
                raise ValueError("Invalid granularity provided.")

    def slice(self, start_ms: int, end_ms: int) -> "TimedText":
        """
        Return a new TimedText object containing only units within [start_ms, end_ms].
        Units must overlap with the interval to be included.
        """
        sliced_units = [
            unit for unit in self.units
            if unit.end_ms > start_ms and unit.start_ms < end_ms
        ]
        return self._new_with_units(sliced_units, self.granularity)

    def filter_by_min_duration(self, min_duration_ms: int) -> "TimedText":
        """
        Return a new TimedText object containing only units with a minimum duration.
        """
        filtered_units = [
            unit for unit in self.units
            if unit.duration_ms >= min_duration_ms
        ]
        return self._new_with_units(filtered_units, self.granularity)

    @classmethod
    def merge(cls, items: List["TimedText"]) -> "TimedText":
        """
        Merge a list of TimedText objects of the same granularity into a single TimedText object.
        """
        if not items:
            raise ValueError("No TimedText objects to merge.")
        granularity = items[0].granularity
        for item in items:
            if item.granularity != granularity:
                raise ValueError("Cannot merge TimedText objects of different granularities.")
        all_units: List[TimedTextUnit] = []
        for item in items:
            all_units.extend(item.units)

        # Use the classmethod to generate with units
        return cls._new_with_units(all_units, granularity)

    def iter(self) -> Iterator[TimedTextUnit]:
        """
        Unified iterator over the units of the correct granularity.
        """
        return iter(self.units)

    def iter_segments(self) -> Iterator[TimedTextUnit]:
        """
        Iterate over segment-level units.

        Raises:
            ValueError: If granularity is not SEGMENT.
        """
        if not self.is_segment_granularity():
            raise ValueError("Cannot call iter_segments() on TimedText with WORD granularity.")
        return iter(self.segments)

    def iter_words(self) -> Iterator[TimedTextUnit]:
        """
        Iterate over word-level units.

        Raises:
            ValueError: If granularity is not WORD.
        """
        if not self.is_word_granularity():
            raise ValueError("Cannot call iter_words() on TimedText with SEGMENT granularity.")
        return iter(self.words)

    def export_text(self, separator: str = "\n", skip_empty: bool = True, show_speaker=True) -> str:
        """
        Export the text content of all units as a single string.

        Args:
            separator: String used to separate units (default: newline).
            skip_empty: If True, skip units with empty or whitespace-only text.
            show_speaker: If True, add speaker info.

        Returns:
            Concatenated text of all units, separated by `separator`.
        """
        def _text_line(unit: TimedTextUnit) -> str:
            if show_speaker and unit.speaker:
                return f"[{unit.speaker}] {unit.text}"
            return unit.text

        texts = [
            _text_line(unit) for unit in self.units
            if not skip_empty or unit.text.strip()
        ]
        return separator.join(texts)
duration property

Get the total duration in milliseconds.

end_ms property

Get the end time of the latest unit.

granularity = Field(..., description='Granularity type for all units.') class-attribute instance-attribute
segments = Field(default_factory=list, description='Phrase-level timed units') class-attribute instance-attribute
start_ms property

Get the start time of the earliest unit.

units property

Return the list of units matching the granularity.

words = Field(default_factory=list, description='Word-level timed units') class-attribute instance-attribute
__init__(*, granularity=None, segments=None, words=None, units=None, **kwargs)

Custom initializer for TimedText. If units is provided, granularity is inferred from the first unit unless explicitly set. If only segments or words is provided, granularity is set accordingly. If all are empty, granularity must be provided.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def __init__(
    self,
    *,
    granularity: Optional[Granularity] = None,
    segments: Optional[List[TimedTextUnit]] = None,
    words: Optional[List[TimedTextUnit]] = None,
    units: Optional[List[TimedTextUnit]] = None,
    **kwargs
):
    """
    Custom initializer for TimedText.
    If `units` is provided, granularity is inferred from the first unit unless explicitly set.
    If only `segments` or `words` is provided, granularity is set accordingly.
    If all are empty, granularity must be provided.
    """
    segments = segments or []
    words = words or []
    if units is not None:
        if units:
            inferred_granularity = units[0].granularity
            granularity = granularity or inferred_granularity
            if granularity == Granularity.SEGMENT:
                segments = units
                words = []
            elif granularity == Granularity.WORD:
                words = units
                segments = []
            else:
                raise ValueError("Invalid granularity inferred from units.")
        else:
            if granularity is None:
                raise ValueError("Must provide granularity for empty TimedText.")
    elif segments:
        granularity = granularity or Granularity.SEGMENT
        words = []
    elif words:
        granularity = granularity or Granularity.WORD
        segments = []
    elif granularity is None:
        raise ValueError("Must provide granularity for empty TimedText.")

    super().__init__(granularity=granularity, segments=segments, words=words, **kwargs)
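The initializer's inference rules above can be summarized in a small standalone function; plain strings and dicts stand in for the real Granularity and TimedTextUnit types (names here are illustrative):

```python
# Standalone sketch of TimedText.__init__'s granularity-inference rules.
def resolve_granularity(units, segments, words, granularity=None):
    if units:
        return granularity or units[0]["granularity"]  # infer from first unit
    if segments:
        return granularity or "segment"
    if words:
        return granularity or "word"
    if granularity is None:
        raise ValueError("Must provide granularity for empty TimedText.")
    return granularity
```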
__len__()

Return the number of units.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def __len__(self) -> int:
    """Return the number of units."""
    return len(self.units)
append(unit)

Add a unit to the end.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def append(self, unit: TimedTextUnit):
    """Add a unit to the end."""
    if unit.granularity != self.granularity:
        raise ValueError(f"Cannot append unit with granularity {unit.granularity} "
                         f"to TimedText of granularity {self.granularity}.")
    self.units.append(unit)
clear()

Remove all units.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def clear(self):
    """Remove all units."""
    self.units.clear()
export_text(separator='\n', skip_empty=True, show_speaker=True)

Export the text content of all units as a single string.

Parameters:

  • separator (str, default '\n'): String used to separate units.
  • skip_empty (bool, default True): If True, skip units with empty or whitespace-only text.
  • show_speaker (default True): If True, add speaker info.

Returns:

  • str: Concatenated text of all units, separated by separator.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def export_text(self, separator: str = "\n", skip_empty: bool = True, show_speaker=True) -> str:
    """
    Export the text content of all units as a single string.

    Args:
        separator: String used to separate units (default: newline).
        skip_empty: If True, skip units with empty or whitespace-only text.
        show_speaker: If True, add speaker info.

    Returns:
        Concatenated text of all units, separated by `separator`.
    """
    def _text_line(unit: TimedTextUnit) -> str:
        if show_speaker and unit.speaker:
            return f"[{unit.speaker}] {unit.text}"
        return unit.text

    texts = [
        _text_line(unit) for unit in self.units
        if not skip_empty or unit.text.strip()
    ]
    return separator.join(texts)
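The filtering and formatting above reduce to a short join; a sketch with plain dicts standing in for units (illustrative names, not the package's API):

```python
# Sketch of export_text's logic; dicts with "speaker"/"text" stand in for units.
def export_text(units, separator="\n", skip_empty=True, show_speaker=True):
    def line(u):
        if show_speaker and u.get("speaker"):
            return f"[{u['speaker']}] {u['text']}"
        return u["text"]
    return separator.join(line(u) for u in units if not skip_empty or u["text"].strip())

units = [
    {"speaker": "Thay", "text": "Breathing in, I calm my body."},
    {"speaker": None, "text": "   "},
    {"speaker": "Sister", "text": "Breathing out, I smile."},
]
```

With the defaults, the whitespace-only unit is skipped and the unit without a speaker would print as bare text.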
extend(units)

Add multiple units to the end.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def extend(self, units: List[TimedTextUnit]):
    """Add multiple units to the end."""
    for unit in units:
        self.append(unit)
filter_by_min_duration(min_duration_ms)

Return a new TimedText object containing only units with a minimum duration.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def filter_by_min_duration(self, min_duration_ms: int) -> "TimedText":
    """
    Return a new TimedText object containing only units with a minimum duration.
    """
    filtered_units = [
        unit for unit in self.units
        if unit.duration_ms >= min_duration_ms
    ]
    return self._new_with_units(filtered_units, self.granularity)
is_segment_granularity()

Return True if granularity is SEGMENT.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def is_segment_granularity(self) -> bool:
    """Return True if granularity is SEGMENT."""
    return self.granularity == Granularity.SEGMENT
is_word_granularity()

Return True if granularity is WORD.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def is_word_granularity(self) -> bool:
    """Return True if granularity is WORD."""
    return self.granularity == Granularity.WORD
iter()

Unified iterator over the units of the correct granularity.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def iter(self) -> Iterator[TimedTextUnit]:
    """
    Unified iterator over the units of the correct granularity.
    """
    return iter(self.units)
iter_segments()

Iterate over segment-level units.

Raises:

Type Description
ValueError

If granularity is not SEGMENT.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def iter_segments(self) -> Iterator[TimedTextUnit]:
    """
    Iterate over segment-level units.

    Raises:
        ValueError: If granularity is not SEGMENT.
    """
    if not self.is_segment_granularity():
        raise ValueError("Cannot call iter_segments() on TimedText with WORD granularity.")
    return iter(self.segments)
iter_words()

Iterate over word-level units.

Raises:

Type Description
ValueError

If granularity is not WORD.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def iter_words(self) -> Iterator[TimedTextUnit]:
    """
    Iterate over word-level units.

    Raises:
        ValueError: If granularity is not WORD.
    """
    if not self.is_word_granularity():
        raise ValueError("Cannot call iter_words() on TimedText with SEGMENT granularity.")
    return iter(self.words)
merge(items) classmethod

Merge a list of TimedText objects of the same granularity into a single TimedText object.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
@classmethod
def merge(cls, items: List["TimedText"]) -> "TimedText":
    """
    Merge a list of TimedText objects of the same granularity into a single TimedText object.
    """
    if not items:
        raise ValueError("No TimedText objects to merge.")
    granularity = items[0].granularity
    for item in items:
        if item.granularity != granularity:
            raise ValueError("Cannot merge TimedText objects of different granularities.")
    all_units: List[TimedTextUnit] = []
    for item in items:
        all_units.extend(item.units)

    # Use the classmethod to generate with units
    return cls._new_with_units(all_units, granularity)
model_post_init(__context)

After initialization, sort units by start time and normalize durations.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def model_post_init(self, __context) -> None:
    """
    After initialization, sort units by start time and normalize durations.
    """
    self.sort_by_start()
    for unit in self.units:
        unit.normalize()
set_all_speakers(speaker)

Set the same speaker for all units.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def set_all_speakers(self, speaker: str) -> None:
    """Set the same speaker for all units."""
    for unit in self.units:
        unit.set_speaker(speaker)
set_speaker(index, speaker)

Set speaker for a specific unit by index.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def set_speaker(self, index: int, speaker: str) -> None:
    """Set speaker for a specific unit by index."""
    if not (0 <= index < len(self.units)):
        raise IndexError(f"Index {index} out of range for units.")
    self.units[index].set_speaker(speaker)
shift(offset_ms)

Shift all units by a given offset in milliseconds.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def shift(self, offset_ms: int) -> None:
    """Shift all units by a given offset in milliseconds."""
    for i, unit in enumerate(self.units):
        self.units[i] = unit.shift_time(offset_ms)
slice(start_ms, end_ms)

Return a new TimedText object containing only units within [start_ms, end_ms]. Units must overlap with the interval to be included.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def slice(self, start_ms: int, end_ms: int) -> "TimedText":
    """
    Return a new TimedText object containing only units within [start_ms, end_ms].
    Units must overlap with the interval to be included.
    """
    sliced_units = [
        unit for unit in self.units
        if unit.end_ms > start_ms and unit.start_ms < end_ms
    ]
    return self._new_with_units(sliced_units, self.granularity)
sort_by_start()

Sort units by start time.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def sort_by_start(self) -> None:
    """Sort units by start time."""
    self.units.sort(key=lambda unit: unit.start_ms)
TimedTextUnit

Bases: BaseModel

Represents a timed unit with timestamps.

A fundamental building block for subtitle and transcript processing that associates text content with start/end times and optional metadata. Can represent either a segment (phrase/sentence) or a word.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
class TimedTextUnit(BaseModel):
    """
    Represents a timed unit with timestamps.

    A fundamental building block for subtitle and transcript processing that
    associates text content with start/end times and optional metadata.
    Can represent either a segment (phrase/sentence) or a word.
    """
    text: str = Field(..., description="The text content")
    start_ms: int = Field(..., description="Start time in milliseconds")
    end_ms: int = Field(..., description="End time in milliseconds")
    speaker: Optional[str] = Field(None, description="Speaker identifier if available")
    index: Optional[int] = Field(None, description="Entry index or sequence number")
    granularity: Granularity
    confidence: Optional[float] = Field(None, description="Optional confidence score")

    @property
    def duration_ms(self) -> int:
        """Get duration in milliseconds."""
        return self.end_ms - self.start_ms

    @property
    def start_sec(self) -> float:
        """Get start time in seconds."""
        return self.start_ms / 1000

    @property
    def end_sec(self) -> float:
        """Get end time in seconds."""
        return self.end_ms / 1000

    @property
    def duration_sec(self) -> float:
        """Get duration in seconds."""
        return self.duration_ms / 1000

    def shift_time(self, offset_ms: int) -> "TimedTextUnit":
        """Create a new TimedUnit with timestamps shifted by offset."""
        return self.model_copy(
            update={
                "start_ms": self.start_ms + offset_ms,
                "end_ms": self.end_ms + offset_ms
            }
        )

    def overlaps_with(self, other: "TimedTextUnit") -> bool:
        """Check if this unit overlaps with another."""
        return (self.start_ms <= other.end_ms and 
                other.start_ms <= self.end_ms)

    def set_speaker(self, speaker: str) -> None:
        """Set the speaker label."""
        self.speaker = speaker

    def normalize(self) -> None:
        """Normalize the duration of the segment to be nonzero"""
        if self.start_ms == self.end_ms:
            self.end_ms = self.start_ms + 1 # minimum duration 

    @field_validator("start_ms", "end_ms")
    @classmethod
    def _validate_time_non_negative(cls, v: int) -> int:
        if v < 0:
            raise ValueError("start_ms and end_ms must be non-negative.")
        return v

    @field_validator("end_ms")
    @classmethod
    def _validate_positive_duration(cls, end_ms: int, info) -> int:
        start_ms = info.data.get("start_ms")
        if start_ms is not None and end_ms < start_ms:
            raise ValueError(
                f"end_ms ({end_ms}) must not be less than start_ms ({start_ms})."
            )
        return end_ms

    @field_validator("text")
    @classmethod
    def _validate_word_text(cls, v: str, info):
        granularity = info.data.get("granularity", "segment")
        if granularity == "word" and (" " in v.strip()):
            raise ValueError(
                "Text for a word-level TimedUnit cannot contain whitespace."
            )
        return v
confidence = Field(None, description='Optional confidence score') class-attribute instance-attribute
duration_ms property

Get duration in milliseconds.

duration_sec property

Get duration in seconds.

end_ms = Field(..., description='End time in milliseconds') class-attribute instance-attribute
end_sec property

Get end time in seconds.

granularity instance-attribute
index = Field(None, description='Entry index or sequence number') class-attribute instance-attribute
speaker = Field(None, description='Speaker identifier if available') class-attribute instance-attribute
start_ms = Field(..., description='Start time in milliseconds') class-attribute instance-attribute
start_sec property

Get start time in seconds.

text = Field(..., description='The text content') class-attribute instance-attribute
normalize()

Normalize the duration of the segment to be nonzero

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def normalize(self) -> None:
    """Normalize the duration of the segment to be nonzero"""
    if self.start_ms == self.end_ms:
        self.end_ms = self.start_ms + 1 # minimum duration 
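The normalization rule above can be shown as a standalone sketch; this is a plain function on timestamps for illustration, not part of the package:

```python
# Standalone sketch of normalize(): a zero-length unit is widened to a
# 1 ms minimum so every unit ends up with a strictly positive duration.
def normalize(start_ms: int, end_ms: int) -> tuple:
    if start_ms == end_ms:
        end_ms = start_ms + 1  # minimum duration
    return start_ms, end_ms

print(normalize(1000, 1000))  # (1000, 1001)
print(normalize(1000, 1200))  # (1000, 1200)
```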
overlaps_with(other)

Check if this unit overlaps with another.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def overlaps_with(self, other: "TimedTextUnit") -> bool:
    """Check if this unit overlaps with another."""
    return (self.start_ms <= other.end_ms and 
            other.start_ms <= self.end_ms)
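The overlap test above is a closed-interval check: two units overlap when each starts no later than the other ends, so units that merely touch at a boundary count as overlapping. A minimal standalone sketch of the same predicate, using `(start_ms, end_ms)` tuples in place of units (the tuple representation is illustrative, not the package's data model):

```python
# Closed-interval overlap test mirroring TimedTextUnit.overlaps_with.
def overlaps(a: tuple, b: tuple) -> bool:
    a_start, a_end = a
    b_start, b_end = b
    return a_start <= b_end and b_start <= a_end

# Touching endpoints count as overlapping under this test.
print(overlaps((0, 1000), (1000, 2000)))  # True
print(overlaps((0, 1000), (1001, 2000)))  # False
```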
set_speaker(speaker)

Set the speaker label.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def set_speaker(self, speaker: str) -> None:
    """Set the speaker label."""
    self.speaker = speaker
shift_time(offset_ms)

Create a new TimedUnit with timestamps shifted by offset.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def shift_time(self, offset_ms: int) -> "TimedTextUnit":
    """Create a new TimedUnit with timestamps shifted by offset."""
    return self.model_copy(
        update={
            "start_ms": self.start_ms + offset_ms,
            "end_ms": self.end_ms + offset_ms
        }
    )
timed_text

Module for handling timed text objects, such as subtitle formats like VTT and SRT.

This module provides classes and utilities for parsing, manipulating, and generating timed text objects useful in subtitle and transcript processing. It uses Pydantic for robust data validation and type safety.

Granularity

Bases: str, Enum

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
class Granularity(str, Enum):
    SEGMENT = "segment"
    WORD = "word"
SEGMENT = 'segment' class-attribute instance-attribute
WORD = 'word' class-attribute instance-attribute
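Because `Granularity` mixes in `str`, its members compare equal to their raw string values, which is what lets plain strings like `"segment"` stand in for enum members elsewhere in the module. The enum below is reproduced from the source above; the comparisons are illustrative:

```python
from enum import Enum

# str-backed Enum: members interchange with their raw string values.
class Granularity(str, Enum):
    SEGMENT = "segment"
    WORD = "word"

print(Granularity.SEGMENT == "segment")        # True
print(Granularity("word") is Granularity.WORD) # True
```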
TimedText

Bases: BaseModel

Represents a collection of timed text units of a single granularity.

Only one of segments or words is populated, determined by granularity. All units must match the declared granularity.

Notes
  • Start times must be non-decreasing (overlaps allowed for multiple speakers).
  • Negative start_ms or end_ms values are not allowed.
  • Durations must be strictly positive (>0 ms).
  • Mixed granularity is strictly prohibited.
Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
class TimedText(BaseModel):
    """
    Represents a collection of timed text units of a single granularity.

    Only one of `segments` or `words` is populated, determined by `granularity`.
    All units must match the declared granularity.

    Notes:
        - Start times must be non-decreasing (overlaps allowed for multiple speakers).
        - Negative start_ms or end_ms values are not allowed.
        - Durations must be strictly positive (>0 ms).
        - Mixed granularity is strictly prohibited.
    """

    granularity: Granularity = Field(..., description="Granularity type for all units.")
    segments: List[TimedTextUnit] = Field(default_factory=list, description="Phrase-level timed units")
    words: List[TimedTextUnit] = Field(default_factory=list, description="Word-level timed units")

    def __init__(
        self,
        *,
        granularity: Optional[Granularity] = None,
        segments: Optional[List[TimedTextUnit]] = None,
        words: Optional[List[TimedTextUnit]] = None,
        units: Optional[List[TimedTextUnit]] = None,
        **kwargs
    ):
        """
        Custom initializer for TimedText.
        If `units` is provided, granularity is inferred from the first unit unless explicitly set.
        If only `segments` or `words` is provided, granularity is set accordingly.
        If all are empty, granularity must be provided.
        """
        segments = segments or []
        words = words or []
        if units is not None:
            if units:
                inferred_granularity = units[0].granularity
                granularity = granularity or inferred_granularity
                if granularity == Granularity.SEGMENT:
                    segments = units
                    words = []
                elif granularity == Granularity.WORD:
                    words = units
                    segments = []
                else:
                    raise ValueError("Invalid granularity inferred from units.")
            else:
                if granularity is None:
                    raise ValueError("Must provide granularity for empty TimedText.")
        elif segments:
            granularity = granularity or Granularity.SEGMENT
            words = []
        elif words:
            granularity = granularity or Granularity.WORD
            segments = []
        elif granularity is None:
            raise ValueError("Must provide granularity for empty TimedText.")

        super().__init__(granularity=granularity, segments=segments, words=words, **kwargs)

    @model_validator(mode="after")
    def _validate_exclusive_granularity(self):
        """
        Validate that TimedText contains only units matching its granularity.
        Allows empty TimedText objects for prototyping and construction.
        Modular logic for segments and words.
        """
        granularity = self.granularity
        segments = self.segments
        words = self.words

        if granularity == Granularity.SEGMENT:
            if words:
                raise ValueError("TimedText with SEGMENT granularity must not have word units.")
            for unit in segments:
                if unit.granularity != Granularity.SEGMENT:
                    raise ValueError("All segment units must have granularity SEGMENT.")
        elif granularity == Granularity.WORD:
            if segments:
                raise ValueError("TimedText with WORD granularity must not have segment units.")
            for unit in words:
                if unit.granularity != Granularity.WORD:
                    raise ValueError("All word units must have granularity WORD.")
        else:
            raise ValueError("Invalid granularity type.")
        return self

    def model_post_init(self, __context) -> None:
        """
        After initialization, sort units by start time and normalize durations.
        """
        self.sort_by_start()
        for unit in self.units:
            unit.normalize()

    @property
    def units(self) -> List[TimedTextUnit]:
        """Return the list of units matching the granularity."""
        return self.segments if self.granularity == Granularity.SEGMENT else self.words

    def is_segment_granularity(self) -> bool:
        """Return True if granularity is SEGMENT."""
        return self.granularity == Granularity.SEGMENT

    def is_word_granularity(self) -> bool:
        """Return True if granularity is WORD."""
        return self.granularity == Granularity.WORD

    @property
    def start_ms(self) -> int:
        """Get the start time of the earliest unit."""
        return min(unit.start_ms for unit in self.units) if self.units else 0

    @property
    def end_ms(self) -> int:
        """Get the end time of the latest unit."""
        return max(unit.end_ms for unit in self.units) if self.units else 0

    @property
    def duration(self) -> int:
        """Get the total duration in milliseconds."""
        return self.end_ms - self.start_ms

    def __len__(self) -> int:
        """Return the number of units."""
        return len(self.units)

    def append(self, unit: TimedTextUnit):
        """Add a unit to the end."""
        if unit.granularity != self.granularity:
            raise ValueError(f"Cannot append unit with granularity {unit.granularity} "
                             "to TimedText of granularity {self.granularity}.")
        self.units.append(unit)

    def extend(self, units: List[TimedTextUnit]):
        """Add multiple units to the end."""
        for unit in units:
            self.append(unit)

    def clear(self):
        """Remove all units."""
        self.units.clear()

    def set_speaker(self, index: int, speaker: str) -> None:
        """Set speaker for a specific unit by index."""
        if not (0 <= index < len(self.units)):
            raise IndexError(f"Index {index} out of range for units.")
        self.units[index].set_speaker(speaker)

    def set_all_speakers(self, speaker: str) -> None:
        """Set the same speaker for all units."""
        for unit in self.units:
            unit.set_speaker(speaker)

    def shift(self, offset_ms: int) -> None:
        """Shift all units by a given offset in milliseconds."""
        for i, unit in enumerate(self.units):
            self.units[i] = unit.shift_time(offset_ms)

    def sort_by_start(self) -> None:
        """Sort units by start time."""
        self.units.sort(key=lambda unit: unit.start_ms)


    @classmethod
    def _new_with_units(
        cls, units: List[TimedTextUnit], granularity: Optional[Granularity] = None
    ) -> "TimedText":
        """
        Helper to create a new TimedText object with the given granularity and units.
        If granularity is not provided, it is inferred from the first unit.
        """
        if units:
            inferred_granularity = units[0].granularity
            granularity = granularity or inferred_granularity
            if granularity == Granularity.SEGMENT:
                return cls(granularity=granularity, segments=units, words=[])
            elif granularity == Granularity.WORD:
                return cls(granularity=granularity, segments=[], words=units)
            else:
                raise ValueError("Invalid granularity inferred from units.")
        else:
            if granularity is None:
                raise ValueError("Must provide granularity for empty TimedText.")
            if granularity in [Granularity.SEGMENT, Granularity.WORD]:
                return cls(granularity=granularity, segments=[], words=[])
            else:
                raise ValueError("Invalid granularity provided.")

    def slice(self, start_ms: int, end_ms: int) -> "TimedText":
        """
        Return a new TimedText object containing only units within [start_ms, end_ms].
        Units must overlap with the interval to be included.
        """
        sliced_units = [
            unit for unit in self.units
            if unit.end_ms > start_ms and unit.start_ms < end_ms
        ]
        return self._new_with_units(sliced_units, self.granularity)

    def filter_by_min_duration(self, min_duration_ms: int) -> "TimedText":
        """
        Return a new TimedText object containing only units with a minimum duration.
        """
        filtered_units = [
            unit for unit in self.units
            if unit.duration_ms >= min_duration_ms
        ]
        return self._new_with_units(filtered_units, self.granularity)

    @classmethod
    def merge(cls, items: List["TimedText"]) -> "TimedText":
        """
        Merge a list of TimedText objects of the same granularity into a single TimedText object.
        """
        if not items:
            raise ValueError("No TimedText objects to merge.")
        granularity = items[0].granularity
        for item in items:
            if item.granularity != granularity:
                raise ValueError("Cannot merge TimedText objects of different granularities.")
        all_units: List[TimedTextUnit] = []
        for item in items:
            all_units.extend(item.units)

        # Use the classmethod to generate with units
        return cls._new_with_units(all_units, granularity)

    def iter(self) -> Iterator[TimedTextUnit]:
        """
        Unified iterator over the units of the correct granularity.
        """
        return iter(self.units)

    def iter_segments(self) -> Iterator[TimedTextUnit]:
        """
        Iterate over segment-level units.

        Raises:
            ValueError: If granularity is not SEGMENT.
        """
        if not self.is_segment_granularity():
            raise ValueError("Cannot call iter_segments() on TimedText with WORD granularity.")
        return iter(self.segments)

    def iter_words(self) -> Iterator[TimedTextUnit]:
        """
        Iterate over word-level units.

        Raises:
            ValueError: If granularity is not WORD.
        """
        if not self.is_word_granularity():
            raise ValueError("Cannot call iter_words() on TimedText with SEGMENT granularity.")
        return iter(self.words)

    def export_text(self, separator: str = "\n", skip_empty: bool = True, show_speaker: bool = True) -> str:
        """
        Export the text content of all units as a single string.

        Args:
            separator: String used to separate units (default: newline).
            skip_empty: If True, skip units with empty or whitespace-only text.
            show_speaker: If True, add speaker info.

        Returns:
            Concatenated text of all units, separated by `separator`.
        """
        def _text_line(unit: TimedTextUnit) -> str:
            if show_speaker and unit.speaker:
                return f"[{unit.speaker}] {unit.text}"
            return unit.text

        texts = [
            _text_line(unit) for unit in self.units
            if not skip_empty or unit.text.strip()
        ]
        return separator.join(texts)
duration property

Get the total duration in milliseconds.

end_ms property

Get the end time of the latest unit.

granularity = Field(..., description='Granularity type for all units.') class-attribute instance-attribute
segments = Field(default_factory=list, description='Phrase-level timed units') class-attribute instance-attribute
start_ms property

Get the start time of the earliest unit.

units property

Return the list of units matching the granularity.

words = Field(default_factory=list, description='Word-level timed units') class-attribute instance-attribute
__init__(*, granularity=None, segments=None, words=None, units=None, **kwargs)

Custom initializer for TimedText. If units is provided, granularity is inferred from the first unit unless explicitly set. If only segments or words is provided, granularity is set accordingly. If all are empty, granularity must be provided.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def __init__(
    self,
    *,
    granularity: Optional[Granularity] = None,
    segments: Optional[List[TimedTextUnit]] = None,
    words: Optional[List[TimedTextUnit]] = None,
    units: Optional[List[TimedTextUnit]] = None,
    **kwargs
):
    """
    Custom initializer for TimedText.
    If `units` is provided, granularity is inferred from the first unit unless explicitly set.
    If only `segments` or `words` is provided, granularity is set accordingly.
    If all are empty, granularity must be provided.
    """
    segments = segments or []
    words = words or []
    if units is not None:
        if units:
            inferred_granularity = units[0].granularity
            granularity = granularity or inferred_granularity
            if granularity == Granularity.SEGMENT:
                segments = units
                words = []
            elif granularity == Granularity.WORD:
                words = units
                segments = []
            else:
                raise ValueError("Invalid granularity inferred from units.")
        else:
            if granularity is None:
                raise ValueError("Must provide granularity for empty TimedText.")
    elif segments:
        granularity = granularity or Granularity.SEGMENT
        words = []
    elif words:
        granularity = granularity or Granularity.WORD
        segments = []
    elif granularity is None:
        raise ValueError("Must provide granularity for empty TimedText.")

    super().__init__(granularity=granularity, segments=segments, words=words, **kwargs)
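The initializer's precedence rules (explicit `units` win, then `segments`, then `words`, and an empty object requires an explicit granularity) can be summarized as a pure function. This is a hypothetical standalone mirror for illustration; strings stand in for `Granularity` members and dicts for units:

```python
# Mirror of TimedText.__init__'s granularity inference order.
def infer_granularity(granularity=None, segments=None, words=None, units=None):
    if units:
        return granularity or units[0]["granularity"]
    if segments:
        return granularity or "segment"
    if words:
        return granularity or "word"
    if granularity is None:
        raise ValueError("Must provide granularity for empty TimedText.")
    return granularity

print(infer_granularity(units=[{"granularity": "word"}]))          # word
print(infer_granularity(segments=[{"granularity": "segment"}]))    # segment
```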
__len__()

Return the number of units.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def __len__(self) -> int:
    """Return the number of units."""
    return len(self.units)
append(unit)

Add a unit to the end.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def append(self, unit: TimedTextUnit):
    """Add a unit to the end."""
    if unit.granularity != self.granularity:
        raise ValueError(f"Cannot append unit with granularity {unit.granularity} "
                         "to TimedText of granularity {self.granularity}.")
    self.units.append(unit)
clear()

Remove all units.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def clear(self):
    """Remove all units."""
    self.units.clear()
export_text(separator='\n', skip_empty=True, show_speaker=True)

Export the text content of all units as a single string.

Parameters:

  separator (str, default '\n'): String used to separate units.
  skip_empty (bool, default True): If True, skip units with empty or whitespace-only text.
  show_speaker (bool, default True): If True, add speaker info.

Returns:

  str: Concatenated text of all units, separated by separator.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def export_text(self, separator: str = "\n", skip_empty: bool = True, show_speaker: bool = True) -> str:
    """
    Export the text content of all units as a single string.

    Args:
        separator: String used to separate units (default: newline).
        skip_empty: If True, skip units with empty or whitespace-only text.
        show_speaker: If True, add speaker info.

    Returns:
        Concatenated text of all units, separated by `separator`.
    """
    def _text_line(unit: TimedTextUnit) -> str:
        if show_speaker and unit.speaker:
            return f"[{unit.speaker}] {unit.text}"
        return unit.text

    texts = [
        _text_line(unit) for unit in self.units
        if not skip_empty or unit.text.strip()
    ]
    return separator.join(texts)
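The formatting rules above (speaker prefix in square brackets, whitespace-only units skipped) can be exercised as a standalone sketch on plain dicts; the dict shape is illustrative, not the package's data model:

```python
# Standalone sketch of export_text's line formatting and filtering.
def export_text(units, separator="\n", skip_empty=True, show_speaker=True):
    def line(u):
        if show_speaker and u.get("speaker"):
            return f"[{u['speaker']}] {u['text']}"
        return u["text"]
    return separator.join(
        line(u) for u in units if not skip_empty or u["text"].strip()
    )

units = [
    {"text": "Breathing in,", "speaker": "Thay"},
    {"text": "   "},                      # whitespace-only: skipped by default
    {"text": "I calm my body."},
]
print(export_text(units))
# [Thay] Breathing in,
# I calm my body.
```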
extend(units)

Add multiple units to the end.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def extend(self, units: List[TimedTextUnit]):
    """Add multiple units to the end."""
    for unit in units:
        self.append(unit)
filter_by_min_duration(min_duration_ms)

Return a new TimedText object containing only units with a minimum duration.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def filter_by_min_duration(self, min_duration_ms: int) -> "TimedText":
    """
    Return a new TimedText object containing only units with a minimum duration.
    """
    filtered_units = [
        unit for unit in self.units
        if unit.duration_ms >= min_duration_ms
    ]
    return self._new_with_units(filtered_units, self.granularity)
is_segment_granularity()

Return True if granularity is SEGMENT.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def is_segment_granularity(self) -> bool:
    """Return True if granularity is SEGMENT."""
    return self.granularity == Granularity.SEGMENT
is_word_granularity()

Return True if granularity is WORD.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def is_word_granularity(self) -> bool:
    """Return True if granularity is WORD."""
    return self.granularity == Granularity.WORD
iter()

Unified iterator over the units of the correct granularity.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def iter(self) -> Iterator[TimedTextUnit]:
    """
    Unified iterator over the units of the correct granularity.
    """
    return iter(self.units)
iter_segments()

Iterate over segment-level units.

Raises:

  ValueError: If granularity is not SEGMENT.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def iter_segments(self) -> Iterator[TimedTextUnit]:
    """
    Iterate over segment-level units.

    Raises:
        ValueError: If granularity is not SEGMENT.
    """
    if not self.is_segment_granularity():
        raise ValueError("Cannot call iter_segments() on TimedText with WORD granularity.")
    return iter(self.segments)
iter_words()

Iterate over word-level units.

Raises:

  ValueError: If granularity is not WORD.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def iter_words(self) -> Iterator[TimedTextUnit]:
    """
    Iterate over word-level units.

    Raises:
        ValueError: If granularity is not WORD.
    """
    if not self.is_word_granularity():
        raise ValueError("Cannot call iter_words() on TimedText with SEGMENT granularity.")
    return iter(self.words)
merge(items) classmethod

Merge a list of TimedText objects of the same granularity into a single TimedText object.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
@classmethod
def merge(cls, items: List["TimedText"]) -> "TimedText":
    """
    Merge a list of TimedText objects of the same granularity into a single TimedText object.
    """
    if not items:
        raise ValueError("No TimedText objects to merge.")
    granularity = items[0].granularity
    for item in items:
        if item.granularity != granularity:
            raise ValueError("Cannot merge TimedText objects of different granularities.")
    all_units: List[TimedTextUnit] = []
    for item in items:
        all_units.extend(item.units)

    # Use the classmethod to generate with units
    return cls._new_with_units(all_units, granularity)
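The merge logic above rejects mixed granularities and concatenates the units (sorting happens afterwards in `model_post_init` on the real class). A hypothetical standalone sketch of the same checks, with dicts and tuples standing in for the package's models:

```python
# Mirror of TimedText.merge's granularity check and unit concatenation.
def merge(items):
    if not items:
        raise ValueError("No TimedText objects to merge.")
    granularity = items[0]["granularity"]
    if any(item["granularity"] != granularity for item in items):
        raise ValueError("Cannot merge TimedText objects of different granularities.")
    units = [u for item in items for u in item["units"]]
    return {"granularity": granularity, "units": units}

a = {"granularity": "segment", "units": [(0, 1000), (1000, 2000)]}
b = {"granularity": "segment", "units": [(2000, 3000)]}
print(merge([a, b])["units"])  # [(0, 1000), (1000, 2000), (2000, 3000)]
```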
model_post_init(__context)

After initialization, sort units by start time and normalize durations.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def model_post_init(self, __context) -> None:
    """
    After initialization, sort units by start time and normalize durations.
    """
    self.sort_by_start()
    for unit in self.units:
        unit.normalize()
set_all_speakers(speaker)

Set the same speaker for all units.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def set_all_speakers(self, speaker: str) -> None:
    """Set the same speaker for all units."""
    for unit in self.units:
        unit.set_speaker(speaker)
set_speaker(index, speaker)

Set speaker for a specific unit by index.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def set_speaker(self, index: int, speaker: str) -> None:
    """Set speaker for a specific unit by index."""
    if not (0 <= index < len(self.units)):
        raise IndexError(f"Index {index} out of range for units.")
    self.units[index].set_speaker(speaker)
shift(offset_ms)

Shift all units by a given offset in milliseconds.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def shift(self, offset_ms: int) -> None:
    """Shift all units by a given offset in milliseconds."""
    for i, unit in enumerate(self.units):
        self.units[i] = unit.shift_time(offset_ms)
slice(start_ms, end_ms)

Return a new TimedText object containing only the units that overlap the interval [start_ms, end_ms].

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def slice(self, start_ms: int, end_ms: int) -> "TimedText":
    """
    Return a new TimedText object containing only the units that overlap
    the interval [start_ms, end_ms].
    """
    sliced_units = [
        unit for unit in self.units
        if unit.end_ms > start_ms and unit.start_ms < end_ms
    ]
    return self._new_with_units(sliced_units, self.granularity)
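The overlap filter used by `slice` can be illustrated standalone; plain `(start_ms, end_ms)` tuples stand in for units here:

```python
from typing import List, Tuple

def slice_units(
    units: List[Tuple[int, int]], start_ms: int, end_ms: int
) -> List[Tuple[int, int]]:
    """Keep units overlapping [start_ms, end_ms]: a unit is included when
    it ends after the window opens and starts before the window closes."""
    return [(s, e) for (s, e) in units if e > start_ms and s < end_ms]
```

Note that a unit touching the window only at its boundary (e.g. ending exactly at `start_ms`) is excluded, since the comparisons are strict.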
sort_by_start()

Sort units by start time.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def sort_by_start(self) -> None:
    """Sort units by start time."""
    self.units.sort(key=lambda unit: unit.start_ms)
TimedTextUnit

Bases: BaseModel

Represents a timed unit with timestamps.

A fundamental building block for subtitle and transcript processing that associates text content with start/end times and optional metadata. Can represent either a segment (phrase/sentence) or a word.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
class TimedTextUnit(BaseModel):
    """
    Represents a timed unit with timestamps.

    A fundamental building block for subtitle and transcript processing that
    associates text content with start/end times and optional metadata.
    Can represent either a segment (phrase/sentence) or a word.
    """
    text: str = Field(..., description="The text content")
    start_ms: int = Field(..., description="Start time in milliseconds")
    end_ms: int = Field(..., description="End time in milliseconds")
    speaker: Optional[str] = Field(None, description="Speaker identifier if available")
    index: Optional[int] = Field(None, description="Entry index or sequence number")
    granularity: Granularity
    confidence: Optional[float] = Field(None, description="Optional confidence score")

    @property
    def duration_ms(self) -> int:
        """Get duration in milliseconds."""
        return self.end_ms - self.start_ms

    @property
    def start_sec(self) -> float:
        """Get start time in seconds."""
        return self.start_ms / 1000

    @property
    def end_sec(self) -> float:
        """Get end time in seconds."""
        return self.end_ms / 1000

    @property
    def duration_sec(self) -> float:
        """Get duration in seconds."""
        return self.duration_ms / 1000

    def shift_time(self, offset_ms: int) -> "TimedTextUnit":
        """Create a new TimedUnit with timestamps shifted by offset."""
        return self.model_copy(
            update={
                "start_ms": self.start_ms + offset_ms,
                "end_ms": self.end_ms + offset_ms
            }
        )

    def overlaps_with(self, other: "TimedTextUnit") -> bool:
        """Check if this unit overlaps with another."""
        return (self.start_ms <= other.end_ms and 
                other.start_ms <= self.end_ms)

    def set_speaker(self, speaker: str) -> None:
        """Set the speaker label."""
        self.speaker = speaker

    def normalize(self) -> None:
        """Normalize the duration of the segment to be nonzero"""
        if self.start_ms == self.end_ms:
            self.end_ms = self.start_ms + 1 # minimum duration 

    @field_validator("start_ms", "end_ms")
    @classmethod
    def _validate_time_non_negative(cls, v: int) -> int:
        if v < 0:
            raise ValueError("start_ms and end_ms must be non-negative.")
        return v

    @field_validator("end_ms")
    @classmethod
    def _validate_positive_duration(cls, end_ms: int, info) -> int:
        start_ms = info.data.get("start_ms")
        if start_ms is not None and end_ms < start_ms:
            raise ValueError(
                f"end_ms ({end_ms}) must be greater than or equal to start_ms ({start_ms})."
            )
        return end_ms

    @field_validator("text")
    @classmethod
    def _validate_word_text(cls, v: str, info):
        granularity = info.data.get("granularity", "segment")
        if granularity == "word" and (" " in v.strip()):
            raise ValueError(
                "Text for a word-level TimedUnit cannot contain whitespace."
            )
        return v
confidence = Field(None, description='Optional confidence score') class-attribute instance-attribute
duration_ms property

Get duration in milliseconds.

duration_sec property

Get duration in seconds.

end_ms = Field(..., description='End time in milliseconds') class-attribute instance-attribute
end_sec property

Get end time in seconds.

granularity instance-attribute
index = Field(None, description='Entry index or sequence number') class-attribute instance-attribute
speaker = Field(None, description='Speaker identifier if available') class-attribute instance-attribute
start_ms = Field(..., description='Start time in milliseconds') class-attribute instance-attribute
start_sec property

Get start time in seconds.

text = Field(..., description='The text content') class-attribute instance-attribute
normalize()

Normalize the duration of the segment to be nonzero

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def normalize(self) -> None:
    """Normalize the duration of the segment to be nonzero"""
    if self.start_ms == self.end_ms:
        self.end_ms = self.start_ms + 1 # minimum duration 
overlaps_with(other)

Check if this unit overlaps with another.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def overlaps_with(self, other: "TimedTextUnit") -> bool:
    """Check if this unit overlaps with another."""
    return (self.start_ms <= other.end_ms and 
            other.start_ms <= self.end_ms)
set_speaker(speaker)

Set the speaker label.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def set_speaker(self, speaker: str) -> None:
    """Set the speaker label."""
    self.speaker = speaker
shift_time(offset_ms)

Create a new TimedUnit with timestamps shifted by offset.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def shift_time(self, offset_ms: int) -> "TimedTextUnit":
    """Create a new TimedUnit with timestamps shifted by offset."""
    return self.model_copy(
        update={
            "start_ms": self.start_ms + offset_ms,
            "end_ms": self.end_ms + offset_ms
        }
    )
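The copy-with-update pattern used here (pydantic's `model_copy(update=...)`) has a plain-dataclass analogue; the `Unit` class below is a minimal hypothetical stand-in for `TimedTextUnit`:

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Unit:
    """Hypothetical stand-in for TimedTextUnit (timing fields only)."""
    text: str
    start_ms: int
    end_ms: int

def shift_time(unit: Unit, offset_ms: int) -> Unit:
    """Return a shifted copy; the original unit is left untouched."""
    return replace(
        unit,
        start_ms=unit.start_ms + offset_ms,
        end_ms=unit.end_ms + offset_ms,
    )
```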

transcription

__all__ = ['patch_whisper_options', 'DiarizationChunker', 'TimedText', 'TextSegmentBuilder', 'TimedTextUnit', 'Granularity', 'TranscriptionService', 'TranscriptionServiceFactory'] module-attribute
DiarizationChunker

Class for chunking diarization results into processing units based on configurable duration targets.

Source code in src/tnh_scholar/audio_processing/diarization/chunker.py
class DiarizationChunker:
    """
    Class for chunking diarization results into processing units
    based on configurable duration targets.
    """

    def __init__(self, **config_options):
        """Initialize chunker with additional config_options."""
        self.config = ChunkConfig()

        self._handle_config_options(config_options)


    def extract_contiguous_chunks(self, segments: List[DiarizedSegment]) -> List[DiarizationChunk]:
        """
        Split diarization segments into contiguous chunks of
        approximately target_duration, without splitting on speaker changes.

        Args:
            segments: List of speaker segments from diarization

        Returns:
        List[DiarizationChunk]: Flat list of contiguous chunks
        """
        if not segments:
            return []

        extractor = self._ChunkExtractor(self.config, split_on_speaker_change=False)
        return extractor.extract(segments)

    class _ChunkExtractor:
        def __init__(self, config: ChunkConfig, split_on_speaker_change: bool = True):
            self.config = config
            self.split_on_speaker_change = split_on_speaker_change
            self.gap_threshold = self.config.gap_threshold
            self.spacing = self.config.gap_spacing_time
            self.chunks: List[DiarizationChunk] = []
            self.current_chunk_segments: List[DiarizedSegment] = []
            self.chunk_start: int = 0
            self.current_speaker = ""
            self.accumulated_time: int = 0

        @property
        def last_segment(self):
            return self.current_chunk_segments[-1] if self.current_chunk_segments else None

        def extract(self, segments: List[DiarizedSegment]) -> List[DiarizationChunk]:
            if not segments:
                return []

            self.chunk_start = int(segments[0].start)
            self.current_speaker = segments[0].speaker
            for segment in segments:
                self._check_segment_duration(segment)  
                self._process_segment(segment)

            self._finalize_last_chunk()
            return self.chunks

        def _process_segment(self, segment: DiarizedSegment):
            if self._should_split(segment):
                self._finalize_current_chunk(segment)
                self.chunk_start = int(segment.start)                
            self._add_segment(segment)

        def _add_segment(self, segment: DiarizedSegment):
            gap_time = self._gap_time(segment)
            if gap_time > self.gap_threshold:
                segment.gap_before = True
                segment.spacing_time = self.spacing
                self.accumulated_time += int(segment.duration) + self.spacing
            else:
                segment.gap_before = False
                segment.spacing_time = max(gap_time, 0)
                self.accumulated_time += int(segment.duration) + gap_time
            self.current_chunk_segments.append(segment)
            self.current_speaker = segment.speaker

        def _gap_time(self, segment) -> int:
            if self.last_segment is None:
                # If no last_segment, this is first segment, so no gap.
                return 0 
            else:
                return segment.start - self.last_segment.end


        def _should_split(self, segment: DiarizedSegment) -> bool:
            gap_time = self._gap_time(segment)
            interval_time = gap_time if gap_time < self.gap_threshold else self.spacing
            accumulated_time = self.accumulated_time + interval_time + segment.duration
            return accumulated_time >= self.config.target_duration 

        def _finalize_current_chunk(self, next_segment: Optional[DiarizedSegment]):
            if self.current_chunk_segments:
                assert self.last_segment is not None
                self.chunks.append(
                    DiarizationChunk(
                        start_time=int(self.chunk_start),
                        end_time=int(self.last_segment.end), 
                        segments=self.current_chunk_segments.copy(),
                        audio=None,
                        accumulated_time=self.accumulated_time
                    )
                )
                self._reset_chunk_state(next_segment)             

        def _reset_chunk_state(self, next_segment):
            self.current_chunk_segments = []
            self.accumulated_time = 0
            if self.split_on_speaker_change and next_segment:
                self.current_speaker = next_segment.speaker

        def _finalize_last_chunk(self):
            if self.current_chunk_segments:
                self._handle_final_segments()

        def _check_segment_duration(self, segment: DiarizedSegment) -> None:
            """Check if segment exceeds target duration and issue warning if needed."""
            if segment.duration > self.config.target_duration:
                logger.warning(f"Found segment longer than "
                            f"target duration: {segment.duration_sec:.0f}s")

        def _handle_final_segments(self) -> None:
            """Append final segments to last chunk if below min duration."""
            approx_remaining_time = sum(segment.duration for segment in self.current_chunk_segments)
            final_time = self.accumulated_time + approx_remaining_time
            min_time = self.config.min_duration

            if final_time < min_time and self.chunks:
                self._merge_to_last_chunk()
            else:
                # Create standalone chunk
                self._finalize_current_chunk(next_segment=None)

        def _merge_to_last_chunk(self):
            """Merge segments to the last chunk processed. self.chunks cannot be empty."""
            assert self.chunks
            self.chunks[-1].segments.extend(self.current_chunk_segments)
            self.chunks[-1].end_time = int(self.current_chunk_segments[-1].end)
            self.chunks[-1].accumulated_time += self.accumulated_time



    def _handle_config_options(self, config_options: Dict[str, Any]) -> None:
        """
        Handles additional configuration options, 
        logging a warning for unrecognized keys.
        """
        for key, value in config_options.items():
            if hasattr(self.config, key):
                setattr(self.config, key, value)
            else:
                logger.warning(f"Unrecognized configuration option: {key}")
config = ChunkConfig() instance-attribute
__init__(**config_options)

Initialize chunker with additional config_options.

Source code in src/tnh_scholar/audio_processing/diarization/chunker.py
def __init__(self, **config_options):
    """Initialize chunker with additional config_options."""
    self.config = ChunkConfig()

    self._handle_config_options(config_options)
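The keyword-option handling that `__init__` delegates to can be sketched on its own. The `ChunkConfig` below is a minimal hypothetical stand-in; its field names and defaults are illustrative, not the real ones:

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)

@dataclass
class ChunkConfig:
    """Minimal hypothetical stand-in for the real chunking config."""
    target_duration: int = 600
    gap_threshold: int = 2
    gap_spacing_time: int = 1

def apply_config_options(config: ChunkConfig, **options) -> ChunkConfig:
    """Set recognized attributes in place; warn rather than fail on unknown keys."""
    for key, value in options.items():
        if hasattr(config, key):
            setattr(config, key, value)
        else:
            logger.warning("Unrecognized configuration option: %s", key)
    return config
```

Warning rather than raising keeps the constructor tolerant of extra keyword arguments while still surfacing likely typos in configuration keys.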
extract_contiguous_chunks(segments)

Split diarization segments into contiguous chunks of approximately target_duration, without splitting on speaker changes.

Parameters:

Name Type Description Default
segments List[DiarizedSegment]

List of speaker segments from diarization

required

Returns:

Type Description
List[DiarizationChunk]

List[DiarizationChunk]: Flat list of contiguous chunks

Source code in src/tnh_scholar/audio_processing/diarization/chunker.py
def extract_contiguous_chunks(self, segments: List[DiarizedSegment]) -> List[DiarizationChunk]:
    """
    Split diarization segments into contiguous chunks of
    approximately target_duration, without splitting on speaker changes.

    Args:
        segments: List of speaker segments from diarization

    Returns:
    List[DiarizationChunk]: Flat list of contiguous chunks
    """
    if not segments:
        return []

    extractor = self._ChunkExtractor(self.config, split_on_speaker_change=False)
    return extractor.extract(segments)
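The core accumulation loop can be sketched standalone, with hypothetical `(start_ms, end_ms)` tuples in place of `DiarizedSegment` and the gap-spacing logic omitted for brevity:

```python
from typing import List, Tuple

Segment = Tuple[int, int]  # hypothetical (start_ms, end_ms) stand-in

def chunk_segments(segments: List[Segment], target_ms: int) -> List[List[Segment]]:
    """Greedily accumulate segments until the running duration reaches
    target_ms, then start a new chunk (a simplified sketch of the
    extractor's split condition, without gap handling)."""
    chunks: List[List[Segment]] = []
    current: List[Segment] = []
    accumulated = 0
    for start, end in segments:
        duration = end - start
        if current and accumulated + duration >= target_ms:
            chunks.append(current)
            current = []
            accumulated = 0
        current.append((start, end))
        accumulated += duration
    if current:  # finalize the trailing chunk
        chunks.append(current)
    return chunks
```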
Granularity

Bases: str, Enum

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
class Granularity(str, Enum):
    SEGMENT = "segment"
    WORD = "word"
SEGMENT = 'segment' class-attribute instance-attribute
WORD = 'word' class-attribute instance-attribute
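Because `Granularity` mixes in `str`, members compare equal to their plain string values, so callers can pass either form. A small self-contained demonstration (the enum is reproduced here rather than imported):

```python
from enum import Enum

class Granularity(str, Enum):
    SEGMENT = "segment"
    WORD = "word"

def is_word_level(granularity) -> bool:
    """Accepts either the enum member or its raw string value."""
    return granularity == Granularity.WORD
```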
TextSegmentBuilder
Source code in src/tnh_scholar/audio_processing/transcription/text_segment_builder.py
class TextSegmentBuilder:
    def __init__(
        self,
        *,
        max_duration_ms: Optional[int] = None, # milliseconds
        target_characters: Optional[int] = None,
        avoid_orphans: bool = True,
        max_gap_duration_ms: Optional[int] = None, # milliseconds
        ignore_speaker: bool = True,
    ):
        self.max_duration = max_duration_ms
        self.target_characters = target_characters
        self.avoid_orphans = avoid_orphans
        self.max_gap_duration = max_gap_duration_ms
        self.ignore_speaker = ignore_speaker

        self.segments: List[TimedTextUnit] = []
        self.current_words: List[TimedTextUnit] = []
        self.current_characters = 0

    def create_segments(self, timed_text: TimedText) -> TimedText:
        # Validate
        if not timed_text.words:
            raise ValueError(
                "TimedText object must have word-level units to build segments."
                )

        for unit in timed_text.words:
            if unit.granularity != Granularity.WORD:
                raise ValueError(f"Expected WORD units, got {unit.granularity}")

        # Process
        for word in timed_text.words:
            if self._should_start_new_segment(word):
                self._flush_current_words()
            self._add_word(word)

        self._flush_current_words()  # Final flush
        return TimedText(segments=self.segments, granularity=Granularity.SEGMENT)

    def _add_word(self, word: TimedTextUnit):
        if self.current_words:
            self.current_characters += 1  # space before the new word
        self.current_characters += len(word.text)
        self.current_words.append(word)


    def _should_start_new_segment(self, word: TimedTextUnit) -> bool:
        if not self.current_words:
            return False

        # Speaker change
        last_word = self.current_words[-1]
        if not self.ignore_speaker and (word.speaker != last_word.speaker):
            return True

        # Significant pause
        if self.max_gap_duration is not None:
            pause = word.start_ms - last_word.end_ms
            if pause > self.max_gap_duration:
                return True

        # End punctuation
        if last_word.text and self._is_punctuation_word(last_word.text):
            return True

        # Max duration
        if self.max_duration is not None:
            duration = word.end_ms - self.current_words[0].start_ms
            if duration > self.max_duration:
                return True

        # Target character count
        if self.target_characters is not None:
            total_chars = self.current_characters + len(word.text) + 1
            if total_chars > self.target_characters:
                return True

        return False

    def _flush_current_words(self):
        if not self.current_words:
            return

        segment_text = " ".join(word.text for word in self.current_words)
        segment = TimedTextUnit(
            text=segment_text,
            start_ms=self.current_words[0].start_ms,
            end_ms=self.current_words[-1].end_ms,
            granularity=Granularity.SEGMENT,
            speaker=None if self.ignore_speaker else self._find_speaker(),
            confidence=None,
            index=None,
        )
        self.segments.append(segment)
        self.current_words = []
        self.current_characters = 0

    def _find_speaker(self) -> Optional[str]:
        """
        Only called when ignore_speaker is False;
        in this case we always split on speaker. 
        So only one speaker is expected. 
        """
        speakers = {word.speaker for word in self.current_words}
        assert len(speakers) == 1, "Inconsistent speakers in segment"
        return speakers.pop()

    def _is_punctuation_word(self, word_text: str) -> bool:
        """
        Check if a word ending in punctuation should trigger a new segment,
        excluding common abbreviations.
        """
        if not word_text:
            return False
        return word_text[-1] in ".!?" and word_text.lower() not in COMMON_ABBREVIATIONS


    def build_segments(
        self,
        *,
        target_duration: Optional[int] = None,
        target_characters: Optional[int] = None,
        avoid_orphans: Optional[bool] = True,
        max_gap_duration: Optional[int] = None,
        ignore_speaker: bool = False,
    ) -> None:
        """
        Build or rebuild `segments` from the contents of `words`.

        Args:
            target_duration: Maximum desired segment duration in milliseconds.
            target_characters: Maximum desired character length of a segment.
            ignore_speaker: Whether to ignore speaker labels (when False, a speaker change starts a new segment).

        Note:
            This is a stub.  Concrete algorithms will be implemented later.

        Raises:
            NotImplementedError: Always, until implemented.
        """
        raise NotImplementedError("build_segments is not yet implemented.")
avoid_orphans = avoid_orphans instance-attribute
current_characters = 0 instance-attribute
current_words = [] instance-attribute
ignore_speaker = ignore_speaker instance-attribute
max_duration = max_duration_ms instance-attribute
max_gap_duration = max_gap_duration_ms instance-attribute
segments = [] instance-attribute
target_characters = target_characters instance-attribute
__init__(*, max_duration_ms=None, target_characters=None, avoid_orphans=True, max_gap_duration_ms=None, ignore_speaker=True)
Source code in src/tnh_scholar/audio_processing/transcription/text_segment_builder.py
def __init__(
    self,
    *,
    max_duration_ms: Optional[int] = None, # milliseconds
    target_characters: Optional[int] = None,
    avoid_orphans: bool = True,
    max_gap_duration_ms: Optional[int] = None, # milliseconds
    ignore_speaker: bool = True,
):
    self.max_duration = max_duration_ms
    self.target_characters = target_characters
    self.avoid_orphans = avoid_orphans
    self.max_gap_duration = max_gap_duration_ms
    self.ignore_speaker = ignore_speaker

    self.segments: List[TimedTextUnit] = []
    self.current_words: List[TimedTextUnit] = []
    self.current_characters = 0
build_segments(*, target_duration=None, target_characters=None, avoid_orphans=True, max_gap_duration=None, ignore_speaker=False)

Build or rebuild segments from the contents of words.

Parameters:

Name Type Description Default
target_duration Optional[int]

Maximum desired segment duration in milliseconds.

None
target_characters Optional[int]

Maximum desired character length of a segment.

None
ignore_speaker bool

Whether to ignore speaker labels (when False, a speaker change starts a new segment).

False
Note

This is a stub. Concrete algorithms will be implemented later.

Raises:

Type Description
NotImplementedError

Always, until implemented.

Source code in src/tnh_scholar/audio_processing/transcription/text_segment_builder.py
def build_segments(
    self,
    *,
    target_duration: Optional[int] = None,
    target_characters: Optional[int] = None,
    avoid_orphans: Optional[bool] = True,
    max_gap_duration: Optional[int] = None,
    ignore_speaker: bool = False,
) -> None:
    """
    Build or rebuild `segments` from the contents of `words`.

    Args:
        target_duration: Maximum desired segment duration in milliseconds.
        target_characters: Maximum desired character length of a segment.
        ignore_speaker: Whether to ignore speaker labels (when False, a speaker change starts a new segment).

    Note:
        This is a stub.  Concrete algorithms will be implemented later.

    Raises:
        NotImplementedError: Always, until implemented.
    """
    raise NotImplementedError("build_segments is not yet implemented.")
create_segments(timed_text)
Source code in src/tnh_scholar/audio_processing/transcription/text_segment_builder.py
def create_segments(self, timed_text: TimedText) -> TimedText:
    # Validate
    if not timed_text.words:
        raise ValueError(
            "TimedText object must have word-level units to build segments."
            )

    for unit in timed_text.words:
        if unit.granularity != Granularity.WORD:
            raise ValueError(f"Expected WORD units, got {unit.granularity}")

    # Process
    for word in timed_text.words:
        if self._should_start_new_segment(word):
            self._flush_current_words()
        self._add_word(word)

    self._flush_current_words()  # Final flush
    return TimedText(segments=self.segments, granularity=Granularity.SEGMENT)
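The word-to-segment grouping that `create_segments` performs can be illustrated standalone, using only the sentence-final punctuation rule (one of several split conditions in `_should_start_new_segment`). Words are hypothetical `(text, start_ms, end_ms)` tuples, and `ABBREVIATIONS` is an illustrative subset, not the real `COMMON_ABBREVIATIONS`:

```python
from typing import List, Tuple

Word = Tuple[str, int, int]  # hypothetical (text, start_ms, end_ms) stand-in

ABBREVIATIONS = {"dr.", "mr.", "e.g.", "i.e."}  # illustrative subset only

def ends_sentence(text: str) -> bool:
    """True when the word ends in sentence-final punctuation and is not
    a known abbreviation (mirroring _is_punctuation_word)."""
    return bool(text) and text[-1] in ".!?" and text.lower() not in ABBREVIATIONS

def build_segments(words: List[Word]) -> List[Word]:
    """Join words into segments, flushing after each sentence boundary."""
    segments: List[Word] = []
    current: List[Word] = []
    for word in words:
        current.append(word)
        if ends_sentence(word[0]):
            segments.append(
                (" ".join(w[0] for w in current), current[0][1], current[-1][2])
            )
            current = []
    if current:  # final flush for any trailing words
        segments.append(
            (" ".join(w[0] for w in current), current[0][1], current[-1][2])
        )
    return segments
```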
TimedText

Bases: BaseModel

Represents a collection of timed text units of a single granularity.

Only one of segments or words is populated, determined by granularity. All units must match the declared granularity.

Notes
  • Start times must be non-decreasing (overlaps allowed for multiple speakers).
  • Negative start_ms or end_ms values are not allowed.
  • Durations must be strictly positive (>0 ms).
  • Mixed granularity is strictly prohibited.
Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
class TimedText(BaseModel):
    """
    Represents a collection of timed text units of a single granularity.

    Only one of `segments` or `words` is populated, determined by `granularity`.
    All units must match the declared granularity.

    Notes:
        - Start times must be non-decreasing (overlaps allowed for multiple speakers).
        - Negative start_ms or end_ms values are not allowed.
        - Durations must be strictly positive (>0 ms).
        - Mixed granularity is strictly prohibited.
    """

    granularity: Granularity = Field(..., description="Granularity type for all units.")
    segments: List[TimedTextUnit] = Field(default_factory=list, description="Phrase-level timed units")
    words: List[TimedTextUnit] = Field(default_factory=list, description="Word-level timed units")

    def __init__(
        self,
        *,
        granularity: Optional[Granularity] = None,
        segments: Optional[List[TimedTextUnit]] = None,
        words: Optional[List[TimedTextUnit]] = None,
        units: Optional[List[TimedTextUnit]] = None,
        **kwargs
    ):
        """
        Custom initializer for TimedText.
        If `units` is provided, granularity is inferred from the first unit unless explicitly set.
        If only `segments` or `words` is provided, granularity is set accordingly.
        If all are empty, granularity must be provided.
        """
        segments = segments or []
        words = words or []
        if units is not None:
            if units:
                inferred_granularity = units[0].granularity
                granularity = granularity or inferred_granularity
                if granularity == Granularity.SEGMENT:
                    segments = units
                    words = []
                elif granularity == Granularity.WORD:
                    words = units
                    segments = []
                else:
                    raise ValueError("Invalid granularity inferred from units.")
            else:
                if granularity is None:
                    raise ValueError("Must provide granularity for empty TimedText.")
        elif segments:
            granularity = granularity or Granularity.SEGMENT
            words = []
        elif words:
            granularity = granularity or Granularity.WORD
            segments = []
        elif granularity is None:
            raise ValueError("Must provide granularity for empty TimedText.")

        super().__init__(granularity=granularity, segments=segments, words=words, **kwargs)

    @model_validator(mode="after")
    def _validate_exclusive_granularity(self):
        """
        Validate that TimedText contains only units matching its granularity.
        Allows empty TimedText objects for prototyping and construction.
        Modular logic for segments and words.
        """
        granularity = self.granularity
        segments = self.segments
        words = self.words

        if granularity == Granularity.SEGMENT:
            if words:
                raise ValueError("TimedText with SEGMENT granularity must not have word units.")
            for unit in segments:
                if unit.granularity != Granularity.SEGMENT:
                    raise ValueError("All segment units must have granularity SEGMENT.")
        elif granularity == Granularity.WORD:
            if segments:
                raise ValueError("TimedText with WORD granularity must not have segment units.")
            for unit in words:
                if unit.granularity != Granularity.WORD:
                    raise ValueError("All word units must have granularity WORD.")
        else:
            raise ValueError("Invalid granularity type.")
        return self

    def model_post_init(self, __context) -> None:
        """
        After initialization, sort units by start time and normalize durations.
        """
        self.sort_by_start()
        for unit in self.units:
            unit.normalize()

    @property
    def units(self) -> List[TimedTextUnit]:
        """Return the list of units matching the granularity."""
        return self.segments if self.granularity == Granularity.SEGMENT else self.words

    def is_segment_granularity(self) -> bool:
        """Return True if granularity is SEGMENT."""
        return self.granularity == Granularity.SEGMENT

    def is_word_granularity(self) -> bool:
        """Return True if granularity is WORD."""
        return self.granularity == Granularity.WORD

    @property
    def start_ms(self) -> int:
        """Get the start time of the earliest unit."""
        return min(unit.start_ms for unit in self.units) if self.units else 0

    @property
    def end_ms(self) -> int:
        """Get the end time of the latest unit."""
        return max(unit.end_ms for unit in self.units) if self.units else 0

    @property
    def duration(self) -> int:
        """Get the total duration in milliseconds."""
        return self.end_ms - self.start_ms

    def __len__(self) -> int:
        """Return the number of units."""
        return len(self.units)

    def append(self, unit: TimedTextUnit):
        """Add a unit to the end."""
        if unit.granularity != self.granularity:
            raise ValueError(f"Cannot append unit with granularity {unit.granularity} "
                             f"to TimedText of granularity {self.granularity}.")
        self.units.append(unit)

    def extend(self, units: List[TimedTextUnit]):
        """Add multiple units to the end."""
        for unit in units:
            self.append(unit)

    def clear(self):
        """Remove all units."""
        self.units.clear()

    def set_speaker(self, index: int, speaker: str) -> None:
        """Set speaker for a specific unit by index."""
        if not (0 <= index < len(self.units)):
            raise IndexError(f"Index {index} out of range for units.")
        self.units[index].set_speaker(speaker)

    def set_all_speakers(self, speaker: str) -> None:
        """Set the same speaker for all units."""
        for unit in self.units:
            unit.set_speaker(speaker)

    def shift(self, offset_ms: int) -> None:
        """Shift all units by a given offset in milliseconds."""
        for i, unit in enumerate(self.units):
            self.units[i] = unit.shift_time(offset_ms)

    def sort_by_start(self) -> None:
        """Sort units by start time."""
        self.units.sort(key=lambda unit: unit.start_ms)


    @classmethod
    def _new_with_units(
        cls, units: List[TimedTextUnit], granularity: Optional[Granularity] = None
    ) -> "TimedText":
        """
        Helper to create a new TimedText object with the given granularity and units.
        If granularity is not provided, it is inferred from the first unit.
        """
        if units:
            inferred_granularity = units[0].granularity
            granularity = granularity or inferred_granularity
            if granularity == Granularity.SEGMENT:
                return cls(granularity=granularity, segments=units, words=[])
            elif granularity == Granularity.WORD:
                return cls(granularity=granularity, segments=[], words=units)
            else:
                raise ValueError("Invalid granularity inferred from units.")
        else:
            if granularity is None:
                raise ValueError("Must provide granularity for empty TimedText.")
            if granularity in [Granularity.SEGMENT, Granularity.WORD]:
                return cls(granularity=granularity, segments=[], words=[])
            else:
                raise ValueError("Invalid granularity provided.")

    def slice(self, start_ms: int, end_ms: int) -> "TimedText":
        """
        Return a new TimedText object containing the units that overlap the
        interval [start_ms, end_ms); a unit need not lie fully within it.
        """
        sliced_units = [
            unit for unit in self.units
            if unit.end_ms > start_ms and unit.start_ms < end_ms
        ]
        return self._new_with_units(sliced_units, self.granularity)

    def filter_by_min_duration(self, min_duration_ms: int) -> "TimedText":
        """
        Return a new TimedText object containing only units with a minimum duration.
        """
        filtered_units = [
            unit for unit in self.units
            if unit.duration_ms >= min_duration_ms
        ]
        return self._new_with_units(filtered_units, self.granularity)

    @classmethod
    def merge(cls, items: List["TimedText"]) -> "TimedText":
        """
        Merge a list of TimedText objects of the same granularity into a single TimedText object.
        """
        if not items:
            raise ValueError("No TimedText objects to merge.")
        granularity = items[0].granularity
        for item in items:
            if item.granularity != granularity:
                raise ValueError("Cannot merge TimedText objects of different granularities.")
        all_units: List[TimedTextUnit] = []
        for item in items:
            all_units.extend(item.units)

        # Use the classmethod to generate with units
        return cls._new_with_units(all_units, granularity)

    def iter(self) -> Iterator[TimedTextUnit]:
        """
        Unified iterator over the units of the correct granularity.
        """
        return iter(self.units)

    def iter_segments(self) -> Iterator[TimedTextUnit]:
        """
        Iterate over segment-level units.

        Raises:
            ValueError: If granularity is not SEGMENT.
        """
        if not self.is_segment_granularity():
            raise ValueError("Cannot call iter_segments() on TimedText with WORD granularity.")
        return iter(self.segments)

    def iter_words(self) -> Iterator[TimedTextUnit]:
        """
        Iterate over word-level units.

        Raises:
            ValueError: If granularity is not WORD.
        """
        if not self.is_word_granularity():
            raise ValueError("Cannot call iter_words() on TimedText with SEGMENT granularity.")
        return iter(self.words)

    def export_text(self, separator: str = "\n", skip_empty: bool = True, show_speaker: bool = True) -> str:
        """
        Export the text content of all units as a single string.

        Args:
            separator: String used to separate units (default: newline).
            skip_empty: If True, skip units with empty or whitespace-only text.
            show_speaker: If True, add speaker info.

        Returns:
            Concatenated text of all units, separated by `separator`.
        """
        def _text_line(unit: TimedTextUnit) -> str:
            if show_speaker and unit.speaker:
                return f"[{unit.speaker}] {unit.text}"
            return unit.text

        texts = [
            _text_line(unit) for unit in self.units
            if not skip_empty or unit.text.strip()
        ]
        return separator.join(texts)
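The overlap rule used by `slice()` above can be illustrated with plain dicts standing in for `TimedTextUnit` objects (the data and helper name here are illustrative only):

```python
# Minimal sketch of the slice() overlap rule, using dicts as unit stand-ins.
units = [
    {"text": "hello", "start_ms": 0, "end_ms": 500},
    {"text": "world", "start_ms": 500, "end_ms": 1200},
    {"text": "again", "start_ms": 1200, "end_ms": 2000},
]

def slice_units(units, start_ms, end_ms):
    # A unit is kept if it overlaps the interval at all,
    # not only if it is fully contained in it.
    return [u for u in units if u["end_ms"] > start_ms and u["start_ms"] < end_ms]

print([u["text"] for u in slice_units(units, 400, 1300)])
# ['hello', 'world', 'again']  — all three overlap [400, 1300)
```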
duration property

Get the total duration in milliseconds.

end_ms property

Get the end time of the latest unit.

granularity = Field(..., description='Granularity type for all units.') class-attribute instance-attribute
segments = Field(default_factory=list, description='Phrase-level timed units') class-attribute instance-attribute
start_ms property

Get the start time of the earliest unit.

units property

Return the list of units matching the granularity.

words = Field(default_factory=list, description='Word-level timed units') class-attribute instance-attribute
__init__(*, granularity=None, segments=None, words=None, units=None, **kwargs)

Custom initializer for TimedText. If units is provided, granularity is inferred from the first unit unless explicitly set. If only segments or words is provided, granularity is set accordingly. If all are empty, granularity must be provided.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def __init__(
    self,
    *,
    granularity: Optional[Granularity] = None,
    segments: Optional[List[TimedTextUnit]] = None,
    words: Optional[List[TimedTextUnit]] = None,
    units: Optional[List[TimedTextUnit]] = None,
    **kwargs
):
    """
    Custom initializer for TimedText.
    If `units` is provided, granularity is inferred from the first unit unless explicitly set.
    If only `segments` or `words` is provided, granularity is set accordingly.
    If all are empty, granularity must be provided.
    """
    segments = segments or []
    words = words or []
    if units is not None:
        if units:
            inferred_granularity = units[0].granularity
            granularity = granularity or inferred_granularity
            if granularity == Granularity.SEGMENT:
                segments = units
                words = []
            elif granularity == Granularity.WORD:
                words = units
                segments = []
            else:
                raise ValueError("Invalid granularity inferred from units.")
        else:
            if granularity is None:
                raise ValueError("Must provide granularity for empty TimedText.")
    elif segments:
        granularity = granularity or Granularity.SEGMENT
        words = []
    elif words:
        granularity = granularity or Granularity.WORD
        segments = []
    elif granularity is None:
        raise ValueError("Must provide granularity for empty TimedText.")

    super().__init__(granularity=granularity, segments=segments, words=words, **kwargs)
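The granularity-inference rules of `__init__` can be sketched as a standalone function; the strings `"segment"` and `"word"` stand in for `Granularity` members here (an assumption for illustration):

```python
# Sketch of the __init__ granularity-inference logic, using plain values.
def infer(granularity, segments, words, units):
    if units:
        # Explicit granularity wins; otherwise infer from the first unit.
        return granularity or units[0]["granularity"]
    if segments:
        return granularity or "segment"
    if words:
        return granularity or "word"
    if granularity is None:
        raise ValueError("Must provide granularity for empty TimedText.")
    return granularity

assert infer(None, [], [], [{"granularity": "word"}]) == "word"
assert infer(None, [{"granularity": "segment"}], [], None) == "segment"
```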
__len__()

Return the number of units.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def __len__(self) -> int:
    """Return the number of units."""
    return len(self.units)
append(unit)

Add a unit to the end.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def append(self, unit: TimedTextUnit):
    """Add a unit to the end."""
    if unit.granularity != self.granularity:
        raise ValueError(f"Cannot append unit with granularity {unit.granularity} "
                         f"to TimedText of granularity {self.granularity}.")
    self.units.append(unit)
clear()

Remove all units.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def clear(self):
    """Remove all units."""
    self.units.clear()
export_text(separator='\n', skip_empty=True, show_speaker=True)

Export the text content of all units as a single string.

Parameters:

Name Type Description Default
separator str

String used to separate units (default: newline).

'\n'
skip_empty bool

If True, skip units with empty or whitespace-only text.

True
show_speaker

If True, add speaker info.

True

Returns:

Type Description
str

Concatenated text of all units, separated by separator.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def export_text(self, separator: str = "\n", skip_empty: bool = True, show_speaker: bool = True) -> str:
    """
    Export the text content of all units as a single string.

    Args:
        separator: String used to separate units (default: newline).
        skip_empty: If True, skip units with empty or whitespace-only text.
        show_speaker: If True, add speaker info.

    Returns:
        Concatenated text of all units, separated by `separator`.
    """
    def _text_line(unit: TimedTextUnit) -> str:
        if show_speaker and unit.speaker:
            return f"[{unit.speaker}] {unit.text}"
        return unit.text

    texts = [
        _text_line(unit) for unit in self.units
        if not skip_empty or unit.text.strip()
    ]
    return separator.join(texts)
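The formatting behavior of `export_text` can be sketched with dict stand-ins for units (the speaker name and sample text are hypothetical):

```python
# Hedged sketch of export_text formatting, using dicts as unit stand-ins.
def text_line(unit, show_speaker=True):
    if show_speaker and unit.get("speaker"):
        return f"[{unit['speaker']}] {unit['text']}"
    return unit["text"]

units = [
    {"text": "Breathing in,", "speaker": "Thay"},
    {"text": "   "},                      # whitespace-only: skipped by default
    {"text": "I calm my body.", "speaker": "Thay"},
]
out = "\n".join(text_line(u) for u in units if u["text"].strip())
print(out)
# [Thay] Breathing in,
# [Thay] I calm my body.
```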
extend(units)

Add multiple units to the end.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def extend(self, units: List[TimedTextUnit]):
    """Add multiple units to the end."""
    for unit in units:
        self.append(unit)
filter_by_min_duration(min_duration_ms)

Return a new TimedText object containing only units with a minimum duration.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def filter_by_min_duration(self, min_duration_ms: int) -> "TimedText":
    """
    Return a new TimedText object containing only units with a minimum duration.
    """
    filtered_units = [
        unit for unit in self.units
        if unit.duration_ms >= min_duration_ms
    ]
    return self._new_with_units(filtered_units, self.granularity)
is_segment_granularity()

Return True if granularity is SEGMENT.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def is_segment_granularity(self) -> bool:
    """Return True if granularity is SEGMENT."""
    return self.granularity == Granularity.SEGMENT
is_word_granularity()

Return True if granularity is WORD.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def is_word_granularity(self) -> bool:
    """Return True if granularity is WORD."""
    return self.granularity == Granularity.WORD
iter()

Unified iterator over the units of the correct granularity.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def iter(self) -> Iterator[TimedTextUnit]:
    """
    Unified iterator over the units of the correct granularity.
    """
    return iter(self.units)
iter_segments()

Iterate over segment-level units.

Raises:

Type Description
ValueError

If granularity is not SEGMENT.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def iter_segments(self) -> Iterator[TimedTextUnit]:
    """
    Iterate over segment-level units.

    Raises:
        ValueError: If granularity is not SEGMENT.
    """
    if not self.is_segment_granularity():
        raise ValueError("Cannot call iter_segments() on TimedText with WORD granularity.")
    return iter(self.segments)
iter_words()

Iterate over word-level units.

Raises:

Type Description
ValueError

If granularity is not WORD.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def iter_words(self) -> Iterator[TimedTextUnit]:
    """
    Iterate over word-level units.

    Raises:
        ValueError: If granularity is not WORD.
    """
    if not self.is_word_granularity():
        raise ValueError("Cannot call iter_words() on TimedText with SEGMENT granularity.")
    return iter(self.words)
merge(items) classmethod

Merge a list of TimedText objects of the same granularity into a single TimedText object.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
@classmethod
def merge(cls, items: List["TimedText"]) -> "TimedText":
    """
    Merge a list of TimedText objects of the same granularity into a single TimedText object.
    """
    if not items:
        raise ValueError("No TimedText objects to merge.")
    granularity = items[0].granularity
    for item in items:
        if item.granularity != granularity:
            raise ValueError("Cannot merge TimedText objects of different granularities.")
    all_units: List[TimedTextUnit] = []
    for item in items:
        all_units.extend(item.units)

    # Use the classmethod to generate with units
    return cls._new_with_units(all_units, granularity)
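The granularity check performed by `merge()` can be sketched on plain dict stand-ins (the structures here are illustrative, not the real pydantic models):

```python
# Sketch of merge()'s same-granularity requirement, on dict stand-ins.
def merge(items):
    if not items:
        raise ValueError("No TimedText objects to merge.")
    granularity = items[0]["granularity"]
    if any(it["granularity"] != granularity for it in items):
        raise ValueError("Cannot merge TimedText objects of different granularities.")
    # Concatenate all units; the real class re-sorts by start time on init.
    return {"granularity": granularity,
            "units": [u for it in items for u in it["units"]]}

a = {"granularity": "segment", "units": [1, 2]}
b = {"granularity": "segment", "units": [3]}
assert merge([a, b])["units"] == [1, 2, 3]
```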
model_post_init(__context)

After initialization, sort units by start time and normalize durations.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def model_post_init(self, __context) -> None:
    """
    After initialization, sort units by start time and normalize durations.
    """
    self.sort_by_start()
    for unit in self.units:
        unit.normalize()
set_all_speakers(speaker)

Set the same speaker for all units.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def set_all_speakers(self, speaker: str) -> None:
    """Set the same speaker for all units."""
    for unit in self.units:
        unit.set_speaker(speaker)
set_speaker(index, speaker)

Set speaker for a specific unit by index.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def set_speaker(self, index: int, speaker: str) -> None:
    """Set speaker for a specific unit by index."""
    if not (0 <= index < len(self.units)):
        raise IndexError(f"Index {index} out of range for units.")
    self.units[index].set_speaker(speaker)
shift(offset_ms)

Shift all units by a given offset in milliseconds.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def shift(self, offset_ms: int) -> None:
    """Shift all units by a given offset in milliseconds."""
    for i, unit in enumerate(self.units):
        self.units[i] = unit.shift_time(offset_ms)
slice(start_ms, end_ms)

Return a new TimedText object containing the units that overlap the interval [start_ms, end_ms); a unit need not lie fully within it.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def slice(self, start_ms: int, end_ms: int) -> "TimedText":
    """
    Return a new TimedText object containing the units that overlap the
    interval [start_ms, end_ms); a unit need not lie fully within it.
    """
    sliced_units = [
        unit for unit in self.units
        if unit.end_ms > start_ms and unit.start_ms < end_ms
    ]
    return self._new_with_units(sliced_units, self.granularity)
sort_by_start()

Sort units by start time.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def sort_by_start(self) -> None:
    """Sort units by start time."""
    self.units.sort(key=lambda unit: unit.start_ms)
TimedTextUnit

Bases: BaseModel

Represents a timed unit with timestamps.

A fundamental building block for subtitle and transcript processing that associates text content with start/end times and optional metadata. Can represent either a segment (phrase/sentence) or a word.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
class TimedTextUnit(BaseModel):
    """
    Represents a timed unit with timestamps.

    A fundamental building block for subtitle and transcript processing that
    associates text content with start/end times and optional metadata.
    Can represent either a segment (phrase/sentence) or a word.
    """
    granularity: Granularity  # declared before `text` so the `text` validator can read it
    text: str = Field(..., description="The text content")
    start_ms: int = Field(..., description="Start time in milliseconds")
    end_ms: int = Field(..., description="End time in milliseconds")
    speaker: Optional[str] = Field(None, description="Speaker identifier if available")
    index: Optional[int] = Field(None, description="Entry index or sequence number")
    confidence: Optional[float] = Field(None, description="Optional confidence score")

    @property
    def duration_ms(self) -> int:
        """Get duration in milliseconds."""
        return self.end_ms - self.start_ms

    @property
    def start_sec(self) -> float:
        """Get start time in seconds."""
        return self.start_ms / 1000

    @property
    def end_sec(self) -> float:
        """Get end time in seconds."""
        return self.end_ms / 1000

    @property
    def duration_sec(self) -> float:
        """Get duration in seconds."""
        return self.duration_ms / 1000

    def shift_time(self, offset_ms: int) -> "TimedTextUnit":
        """Create a new TimedTextUnit with timestamps shifted by offset_ms."""
        return self.model_copy(
            update={
                "start_ms": self.start_ms + offset_ms,
                "end_ms": self.end_ms + offset_ms
            }
        )

    def overlaps_with(self, other: "TimedTextUnit") -> bool:
        """Check if this unit overlaps with another."""
        return (self.start_ms <= other.end_ms and 
                other.start_ms <= self.end_ms)

    def set_speaker(self, speaker: str) -> None:
        """Set the speaker label."""
        self.speaker = speaker

    def normalize(self) -> None:
        """Normalize the duration to be nonzero (minimum 1 ms)."""
        if self.start_ms == self.end_ms:
            self.end_ms = self.start_ms + 1  # enforce a minimum 1 ms duration

    @field_validator("start_ms", "end_ms")
    @classmethod
    def _validate_time_non_negative(cls, v: int) -> int:
        if v < 0:
            raise ValueError("start_ms and end_ms must be non-negative.")
        return v

    @field_validator("end_ms")
    @classmethod
    def _validate_positive_duration(cls, end_ms: int, info) -> int:
        start_ms = info.data.get("start_ms")
        if start_ms is not None and end_ms < start_ms:
            raise ValueError(
                f"end_ms ({end_ms}) must be greater than or equal to start_ms ({start_ms})."
            )
        return end_ms

    @field_validator("text")
    @classmethod
    def _validate_word_text(cls, v: str, info):
        granularity = info.data.get("granularity", Granularity.SEGMENT)
        if granularity == Granularity.WORD and (" " in v.strip()):
            raise ValueError(
                "Text for a word-level TimedTextUnit cannot contain whitespace."
            )
        return v
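The copy-on-write behavior of `shift_time` and the symmetric overlap test can be illustrated with a simplified, pydantic-free stand-in for `TimedTextUnit` (names and data here are illustrative only):

```python
from dataclasses import dataclass, replace

# Simplified stand-in for TimedTextUnit, mirroring shift_time and overlaps_with.
@dataclass(frozen=True)
class Unit:
    text: str
    start_ms: int
    end_ms: int

    def shift_time(self, offset_ms: int) -> "Unit":
        # Returns a new instance; the original is untouched.
        return replace(self, start_ms=self.start_ms + offset_ms,
                       end_ms=self.end_ms + offset_ms)

    def overlaps_with(self, other: "Unit") -> bool:
        # Symmetric: closed intervals that merely touch count as overlapping.
        return self.start_ms <= other.end_ms and other.start_ms <= self.end_ms

a = Unit("hello", 0, 500)
b = a.shift_time(400)
assert a.start_ms == 0                 # original unchanged
assert b.start_ms == 400 and b.end_ms == 900
assert a.overlaps_with(b)              # [0, 500] touches [400, 900]
```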
confidence = Field(None, description='Optional confidence score') class-attribute instance-attribute
duration_ms property

Get duration in milliseconds.

duration_sec property

Get duration in seconds.

end_ms = Field(..., description='End time in milliseconds') class-attribute instance-attribute
end_sec property

Get end time in seconds.

granularity instance-attribute
index = Field(None, description='Entry index or sequence number') class-attribute instance-attribute
speaker = Field(None, description='Speaker identifier if available') class-attribute instance-attribute
start_ms = Field(..., description='Start time in milliseconds') class-attribute instance-attribute
start_sec property

Get start time in seconds.

text = Field(..., description='The text content') class-attribute instance-attribute
normalize()

Normalize the duration to be nonzero (minimum 1 ms).

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def normalize(self) -> None:
    """Normalize the duration to be nonzero (minimum 1 ms)."""
    if self.start_ms == self.end_ms:
        self.end_ms = self.start_ms + 1  # enforce a minimum 1 ms duration
overlaps_with(other)

Check if this unit overlaps with another.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def overlaps_with(self, other: "TimedTextUnit") -> bool:
    """Check if this unit overlaps with another."""
    return (self.start_ms <= other.end_ms and 
            other.start_ms <= self.end_ms)
set_speaker(speaker)

Set the speaker label.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def set_speaker(self, speaker: str) -> None:
    """Set the speaker label."""
    self.speaker = speaker
shift_time(offset_ms)

Create a new TimedTextUnit with timestamps shifted by offset_ms.

Source code in src/tnh_scholar/audio_processing/timed_object/timed_text.py
def shift_time(self, offset_ms: int) -> "TimedTextUnit":
    """Create a new TimedTextUnit with timestamps shifted by offset_ms."""
    return self.model_copy(
        update={
            "start_ms": self.start_ms + offset_ms,
            "end_ms": self.end_ms + offset_ms
        }
    )
TranscriptionService

Bases: ABC

Abstract base class defining the interface for transcription services.

This interface provides a standard way to interact with different transcription service providers (e.g., OpenAI Whisper, AssemblyAI).

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
class TranscriptionService(ABC):
    """
    Abstract base class defining the interface for transcription services.

    This interface provides a standard way to interact with different
    transcription service providers (e.g., OpenAI Whisper, AssemblyAI).
    """

    @abstractmethod
    def transcribe(
        self,
        audio_file: Union[Path, BytesIO],
        options: Optional[Dict[str, Any]] = None
    ) -> TranscriptionResult:
        """
        Transcribe audio file to text.

        Args:
            audio_file: Path to audio file or file-like object
            options: Provider-specific options for transcription

        Returns:
            TranscriptionResult

        """
        pass

    @abstractmethod
    def get_result(self, job_id: str) -> TranscriptionResult:
        """
        Get results for an existing transcription job.

        Args:
            job_id: ID of the transcription job

        Returns:
            TranscriptionResult containing the transcription in the same
            standardized format as transcribe()
        """
        pass

    @abstractmethod
    def transcribe_to_format(
        self,
        audio_file: Union[Path, BytesIO],
        format_type: str = "srt",
        transcription_options: Optional[Dict[str, Any]] = None,
        format_options: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Transcribe audio and return result in specified format.

        Args:
            audio_file: Path, file-like object, or URL of audio file
            format_type: Format type (e.g., "srt", "vtt", "text")
            transcription_options: Options for transcription
            format_options: Format-specific options

        Returns:
            String representation in the requested format
        """
        pass
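A provider implements the interface by subclassing and filling in all three abstract methods. The sketch below uses a simplified copy of the ABC and a hypothetical `EchoService` with made-up return payloads, purely to show the shape of a conforming implementation:

```python
from abc import ABC, abstractmethod

# Simplified copy of the TranscriptionService ABC (signatures abbreviated).
class TranscriptionService(ABC):
    @abstractmethod
    def transcribe(self, audio_file, options=None): ...
    @abstractmethod
    def get_result(self, job_id): ...
    @abstractmethod
    def transcribe_to_format(self, audio_file, format_type="srt",
                             transcription_options=None, format_options=None): ...

# Hypothetical provider: echoes its inputs instead of calling a real API.
class EchoService(TranscriptionService):
    def transcribe(self, audio_file, options=None):
        return {"text": f"transcribed:{audio_file}"}

    def get_result(self, job_id):
        return {"job_id": job_id}

    def transcribe_to_format(self, audio_file, format_type="srt",
                             transcription_options=None, format_options=None):
        return f"{format_type}:{audio_file}"

svc = EchoService()
assert svc.transcribe_to_format("talk.mp3", "vtt") == "vtt:talk.mp3"
```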
get_result(job_id) abstractmethod

Get results for an existing transcription job.

Parameters:

Name Type Description Default
job_id str

ID of the transcription job

required

Returns:

Type Description
TranscriptionResult

TranscriptionResult containing the transcription in the same standardized format as transcribe()

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
@abstractmethod
def get_result(self, job_id: str) -> TranscriptionResult:
    """
    Get results for an existing transcription job.

    Args:
        job_id: ID of the transcription job

    Returns:
        TranscriptionResult containing the transcription in the same
        standardized format as transcribe()
    """
    pass
transcribe(audio_file, options=None) abstractmethod

Transcribe audio file to text.

Parameters:

Name Type Description Default
audio_file Union[Path, BytesIO]

Path to audio file or file-like object

required
options Optional[Dict[str, Any]]

Provider-specific options for transcription

None

Returns:

Type Description
TranscriptionResult

TranscriptionResult

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
@abstractmethod
def transcribe(
    self,
    audio_file: Union[Path, BytesIO],
    options: Optional[Dict[str, Any]] = None
) -> TranscriptionResult:
    """
    Transcribe audio file to text.

    Args:
        audio_file: Path to audio file or file-like object
        options: Provider-specific options for transcription

    Returns:
        TranscriptionResult

    """
    pass
transcribe_to_format(audio_file, format_type='srt', transcription_options=None, format_options=None) abstractmethod

Transcribe audio and return result in specified format.

Parameters:

Name Type Description Default
audio_file Union[Path, BytesIO]

Path, file-like object, or URL of audio file

required
format_type str

Format type (e.g., "srt", "vtt", "text")

'srt'
transcription_options Optional[Dict[str, Any]]

Options for transcription

None
format_options Optional[Dict[str, Any]]

Format-specific options

None

Returns:

Type Description
str

String representation in the requested format

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
@abstractmethod
def transcribe_to_format(
    self,
    audio_file: Union[Path, BytesIO],
    format_type: str = "srt",
    transcription_options: Optional[Dict[str, Any]] = None,
    format_options: Optional[Dict[str, Any]] = None
) -> str:
    """
    Transcribe audio and return result in specified format.

    Args:
        audio_file: Path, file-like object, or URL of audio file
        format_type: Format type (e.g., "srt", "vtt", "text")
        transcription_options: Options for transcription
        format_options: Format-specific options

    Returns:
        String representation in the requested format
    """
    pass
TranscriptionServiceFactory

Factory for creating transcription service instances.

This factory provides a standard way to create transcription service instances based on the provider name and configuration.

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
class TranscriptionServiceFactory:
    """
    Factory for creating transcription service instances.

    This factory provides a standard way to create transcription service
    instances based on the provider name and configuration.
    """

    # Mapping provider names to implementation classes
    # Classes will be imported lazily when needed
    _PROVIDER_MAP: Dict[str, Callable[..., TranscriptionService]] = {}

    @classmethod
    def register_provider(
        cls, 
        name: str, 
        provider_class: Callable[..., TranscriptionService]
    ) -> None:
        """
        Register a provider implementation with the factory.

        Args:
            name: Provider name (lowercase)
            provider_class: Provider implementation class or factory function

        Example:
            >>> from my_module import MyTranscriptionService
            >>> TranscriptionServiceFactory.register_provider("my_provider", MyTranscriptionService)
        """  
        cls._PROVIDER_MAP[name.lower()] = provider_class

    @classmethod
    def create_service(
        cls,
        provider: str = "assemblyai",
        api_key: Optional[str] = None,
        **kwargs
    ) -> TranscriptionService:
        """
        Create a transcription service instance.

        Args:
            provider: Service provider name (e.g., "whisper", "assemblyai")
            api_key: API key for the service
            **kwargs: Additional provider-specific configuration

        Returns:
            TranscriptionService instance

        Raises:
            ValueError: If the provider is not supported
            ImportError: If the provider module cannot be imported
        """
        provider = provider.lower()

        # Initialize provider map if empty
        if not cls._PROVIDER_MAP:
            # Import lazily to avoid circular imports
            from .assemblyai_service import AAITranscriptionService
            from .whisper_service import WhisperTranscriptionService

            cls._PROVIDER_MAP = {
                "whisper": WhisperTranscriptionService,
                "assemblyai": AAITranscriptionService,
            }

        # Get the provider implementation
        provider_class = cls._PROVIDER_MAP.get(provider)

        if provider_class is None:
            raise ValueError(f"Unsupported transcription provider: {provider}")

        # Create and return the service instance
        return provider_class(api_key=api_key, **kwargs)
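The register-then-create flow can be sketched with a simplified stand-in factory (no lazy imports) and a hypothetical `DummyService` provider:

```python
# Simplified stand-in for TranscriptionServiceFactory, illustration only.
class Factory:
    _PROVIDER_MAP = {}

    @classmethod
    def register_provider(cls, name, provider_class):
        # Provider names are stored lowercase, as in the real factory.
        cls._PROVIDER_MAP[name.lower()] = provider_class

    @classmethod
    def create_service(cls, provider, api_key=None, **kwargs):
        provider_class = cls._PROVIDER_MAP.get(provider.lower())
        if provider_class is None:
            raise ValueError(f"Unsupported transcription provider: {provider}")
        return provider_class(api_key=api_key, **kwargs)

# Hypothetical provider class used only for this sketch.
class DummyService:
    def __init__(self, api_key=None):
        self.api_key = api_key

Factory.register_provider("Dummy", DummyService)
svc = Factory.create_service("dummy", api_key="sk-test")
assert isinstance(svc, DummyService) and svc.api_key == "sk-test"
```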
create_service(provider='assemblyai', api_key=None, **kwargs) classmethod

Create a transcription service instance.

Parameters:

- provider (str): Service provider name (e.g., "whisper", "assemblyai"). Default: 'assemblyai'
- api_key (Optional[str]): API key for the service. Default: None
- **kwargs: Additional provider-specific configuration. Default: {}

Returns:

- TranscriptionService: TranscriptionService instance

Raises:

- ValueError: If the provider is not supported
- ImportError: If the provider module cannot be imported

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
@classmethod
def create_service(
    cls,
    provider: str = "assemblyai",
    api_key: Optional[str] = None,
    **kwargs
) -> TranscriptionService:
    """
    Create a transcription service instance.

    Args:
        provider: Service provider name (e.g., "whisper", "assemblyai")
        api_key: API key for the service
        **kwargs: Additional provider-specific configuration

    Returns:
        TranscriptionService instance

    Raises:
        ValueError: If the provider is not supported
        ImportError: If the provider module cannot be imported
    """
    provider = provider.lower()

    # Initialize provider map if empty
    if not cls._PROVIDER_MAP:
        # Import lazily to avoid circular imports
        from .assemblyai_service import AAITranscriptionService
        from .whisper_service import WhisperTranscriptionService

        cls._PROVIDER_MAP = {
            "whisper": WhisperTranscriptionService,
            "assemblyai": AAITranscriptionService,
        }

    # Get the provider implementation
    provider_class = cls._PROVIDER_MAP.get(provider)

    if provider_class is None:
        raise ValueError(f"Unsupported transcription provider: {provider}")

    # Create and return the service instance
    return provider_class(api_key=api_key, **kwargs)
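The lookup logic above can be sketched with stand-in providers; MockWhisper and MockAssemblyAI below are hypothetical placeholders for the package's real service classes, used only to make the factory pattern runnable in isolation.

```python
from typing import Callable, Dict

# Hypothetical stand-ins for the providers imported lazily by the real factory.
class MockWhisper:
    def __init__(self, api_key=None, **kwargs):
        self.api_key = api_key

class MockAssemblyAI:
    def __init__(self, api_key=None, **kwargs):
        self.api_key = api_key

_PROVIDER_MAP: Dict[str, Callable] = {}

def create_service(provider: str = "assemblyai", api_key=None, **kwargs):
    # Lazy initialization mirrors the real factory's import-on-first-use.
    if not _PROVIDER_MAP:
        _PROVIDER_MAP.update({"whisper": MockWhisper, "assemblyai": MockAssemblyAI})
    provider_class = _PROVIDER_MAP.get(provider.lower())
    if provider_class is None:
        raise ValueError(f"Unsupported transcription provider: {provider}")
    return provider_class(api_key=api_key, **kwargs)

service = create_service("Whisper", api_key="test-key")
```

Note that the provider name is lowercased before lookup, so "Whisper" and "whisper" resolve to the same class.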
register_provider(name, provider_class) classmethod

Register a provider implementation with the factory.

Parameters:

- name (str): Provider name (lowercase). Required.
- provider_class (Callable[..., TranscriptionService]): Provider implementation class or factory function. Required.

Example:

>>> from my_module import MyTranscriptionService
>>> TranscriptionServiceFactory.register_provider("my_provider", MyTranscriptionService)

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
@classmethod
def register_provider(
    cls, 
    name: str, 
    provider_class: Callable[..., TranscriptionService]
) -> None:
    """
    Register a provider implementation with the factory.

    Args:
        name: Provider name (lowercase)
        provider_class: Provider implementation class or factory function

    Example:
        >>> from my_module import MyTranscriptionService
        >>> TranscriptionServiceFactory.register_provider("my_provider", MyTranscriptionService)
    """  
    cls._PROVIDER_MAP[name.lower()] = provider_class
patch_whisper_options(options, file_extension)

Patch routine to ensure 'file_extension' is present in transcription options dict. This is a workaround for OpenAI Whisper API, which requires file-like objects to have a filename/extension. Only allows known audio extensions.

Parameters:

- options (Optional[Dict[str, Any]]): Transcription options dictionary (will not be mutated). Required.
- file_extension (str): File extension string (with or without leading dot). Required.

Returns:

- Dict[str, Any]: New options dictionary with 'file_extension' set appropriately

Raises:

- ValueError: If file_extension is not in the allowed list

Source code in src/tnh_scholar/audio_processing/transcription/patches.py
def patch_whisper_options(
    options: Optional[Dict[str, Any]],
    file_extension: str
    ) -> Dict[str, Any]:
    """
    Patch routine to ensure 'file_extension' is present in transcription options dict.
    This is a workaround for OpenAI Whisper API, which requires file-like objects to have a
    filename/extension. Only allows known audio extensions.

    Args:
        options: Transcription options dictionary (will not be mutated)
        file_extension: File extension string (with or without leading dot)

    Returns:
        New options dictionary with 'file_extension' set appropriately

    Raises:
        ValueError: If file_extension is not in the allowed list
    """
    patched = dict(options) if options is not None else {}
    ext = file_extension.lstrip('.')
    if ext.lower() not in _ALLOWED_EXTENSIONS:
        raise ValueError(
            f"Unsupported file extension '{ext}'. Allowed extensions: {_ALLOWED_EXTENSIONS}"
        )
    patched['file_extension'] = ext
    return patched
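The patch routine can be exercised as below; this is a self-contained re-implementation of the same pattern, and the allowed-extension set here is an assumption for illustration, not the package's actual list.

```python
from typing import Any, Dict, Optional

# Assumed extension set; the package defines its own allowed list.
_ALLOWED_EXTENSIONS = {"mp3", "mp4", "wav", "m4a", "flac", "ogg", "webm"}

def patch_whisper_options(
    options: Optional[Dict[str, Any]], file_extension: str
) -> Dict[str, Any]:
    patched = dict(options) if options is not None else {}  # copy, never mutate
    ext = file_extension.lstrip(".")
    if ext.lower() not in _ALLOWED_EXTENSIONS:
        raise ValueError(f"Unsupported file extension '{ext}'")
    patched["file_extension"] = ext
    return patched

original = {"language": "vi"}
patched = patch_whisper_options(original, ".mp3")
```

The caller's dictionary is left untouched; only the returned copy carries the added 'file_extension' key.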
assemblyai_service

AssemblyAI implementation of the TranscriptionService interface.

This module provides a complete implementation of the TranscriptionService interface using the AssemblyAI Python SDK, with support for all major features including:

  • Transcription with configurable options
  • Speaker diarization
  • Automatic language detection
  • Audio intelligence features
  • Subtitle generation
  • Regional endpoint support
  • Webhook callbacks

The implementation follows a modular design with single-action methods and supports both synchronous and asynchronous usage patterns.

logger = get_child_logger(__name__) module-attribute
AAIConfig dataclass

Comprehensive configuration for AssemblyAI transcription service.

This class contains all configurable options for the AssemblyAI API, organized by feature category.

Source code in src/tnh_scholar/audio_processing/transcription/assemblyai_service.py
@dataclass
class AAIConfig:
    """
    Comprehensive configuration for AssemblyAI transcription service.

    This class contains all configurable options for the AssemblyAI API,
    organized by feature category.
    """

    # Base configuration
    api_key: Optional[str] = None
    use_eu_endpoint: bool = False

    # Connection configuration
    polling_interval: int = 4

    # Core transcription configuration
    speech_model: SpeechModel = SpeechModel.BEST
    language_code: Optional[str] = None
    language_detection: bool = True
    dual_channel: bool = False

    # Text formatting options
    format_text: bool = True
    punctuate: bool = True
    disfluencies: bool = False
    filter_profanity: bool = False

    # Subtitle options
    chars_per_caption: int = 60

    # Speaker options
    speaker_labels: bool = True
    speakers_expected: Optional[int] = None

    # Custom vocabulary options
    custom_spelling: Dict[str, str] = field(default_factory=dict)
    word_boost: List[str] = field(default_factory=list)

    # Audio intelligence configuration
    auto_chapters: bool = False
    auto_highlights: bool = False
    entity_detection: bool = False
    iab_categories: bool = False
    sentiment_analysis: bool = False
    summarization: bool = False
    content_safety: bool = False

    # Callback options (Webhook functionality currently not implemented)
    # The transcribe_async method provides asynchronous processing
    webhook_url: Optional[str] = None
    webhook_auth_header_name: Optional[str] = None
    webhook_auth_header_value: Optional[str] = None
api_key = None class-attribute instance-attribute
auto_chapters = False class-attribute instance-attribute
auto_highlights = False class-attribute instance-attribute
chars_per_caption = 60 class-attribute instance-attribute
content_safety = False class-attribute instance-attribute
custom_spelling = field(default_factory=dict) class-attribute instance-attribute
disfluencies = False class-attribute instance-attribute
dual_channel = False class-attribute instance-attribute
entity_detection = False class-attribute instance-attribute
filter_profanity = False class-attribute instance-attribute
format_text = True class-attribute instance-attribute
iab_categories = False class-attribute instance-attribute
language_code = None class-attribute instance-attribute
language_detection = True class-attribute instance-attribute
polling_interval = 4 class-attribute instance-attribute
punctuate = True class-attribute instance-attribute
sentiment_analysis = False class-attribute instance-attribute
speaker_labels = True class-attribute instance-attribute
speakers_expected = None class-attribute instance-attribute
speech_model = SpeechModel.BEST class-attribute instance-attribute
summarization = False class-attribute instance-attribute
use_eu_endpoint = False class-attribute instance-attribute
webhook_auth_header_name = None class-attribute instance-attribute
webhook_auth_header_value = None class-attribute instance-attribute
webhook_url = None class-attribute instance-attribute
word_boost = field(default_factory=list) class-attribute instance-attribute
__init__(api_key=None, use_eu_endpoint=False, polling_interval=4, speech_model=SpeechModel.BEST, language_code=None, language_detection=True, dual_channel=False, format_text=True, punctuate=True, disfluencies=False, filter_profanity=False, chars_per_caption=60, speaker_labels=True, speakers_expected=None, custom_spelling=dict(), word_boost=list(), auto_chapters=False, auto_highlights=False, entity_detection=False, iab_categories=False, sentiment_analysis=False, summarization=False, content_safety=False, webhook_url=None, webhook_auth_header_name=None, webhook_auth_header_value=None)
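Note that custom_spelling and word_boost are declared with field(default_factory=...): a plain mutable default would be shared across every AAIConfig instance. A trimmed, runnable stand-in (MiniConfig is hypothetical) shows why this matters:

```python
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class MiniConfig:  # trimmed stand-in for AAIConfig
    speaker_labels: bool = True
    custom_spelling: Dict[str, str] = field(default_factory=dict)
    word_boost: List[str] = field(default_factory=list)

a = MiniConfig()
b = MiniConfig(word_boost=["Thich Nhat Hanh"])
a.word_boost.append("Plum Village")  # mutates only a's list, not b's
```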
AAITranscriptionService

Bases: TranscriptionService

AssemblyAI implementation of the TranscriptionService interface.

Provides comprehensive access to AssemblyAI's transcription services with support for all major features through the official Python SDK.

Source code in src/tnh_scholar/audio_processing/transcription/assemblyai_service.py
class AAITranscriptionService(TranscriptionService):
    """
    AssemblyAI implementation of the TranscriptionService interface.

    Provides comprehensive access to AssemblyAI's transcription services
    with support for all major features through the official Python SDK.
    """

    def __init__(
        self, 
        api_key: Optional[str] = None, 
        options: Optional[Dict[str, Any]] = None,
        ):
        """
        Initialize the AssemblyAI transcription service.

        Args:
            api_key: AssemblyAI API key (defaults to ASSEMBLYAI_API_KEY env var)
            options: Provider-specific transcription options
        """
        # Initialize format converter for fallback cases
        self.format_converter = FormatConverter()

        # Initialize configuration with AAIConfig defaults
        self.config = AAIConfig()

        # Configure SDK
        self._configure_sdk(api_key)

        # Create transcriber instance
        self.transcriber = aai.Transcriber(
            config=self._create_transcription_config(options)
            )

        logger.debug("Initialized AssemblyAI service with SDK")

    def _configure_sdk(self, api_key: Optional[str] = None) -> None:
        """
        Configure the AssemblyAI SDK with API key and regional settings.

        Args:
            api_key: AssemblyAI API key

        Raises:
            ValueError: If no API key is provided or found in environment
        """
        # Set API key - priority: parameter > config > env var
        api_key = api_key or self.config.api_key or os.getenv("ASSEMBLYAI_API_KEY")
        if not api_key:
            raise ValueError(
                "AssemblyAI API key is required. Set ASSEMBLYAI_API_KEY environment "
                "variable, pass as api_key parameter, or include in config."
            )

        # Configure SDK settings
        aai.settings.api_key = api_key
        aai.settings.polling_interval = self.config.polling_interval

        # Configure regional settings
        if self.config.use_eu_endpoint:
            aai.settings.base_url = "https://api.eu.assemblyai.com/v2"
            logger.debug("Using EU endpoint for AssemblyAI API")

    def _create_transcription_config(
        self, 
        options: Optional[Dict[str, Any]] = None
    ) -> aai.TranscriptionConfig:
        """
        Create a TranscriptionConfig object from configuration and options.

        Args:
            options: Additional options to override configuration

        Returns:
            Configured TranscriptionConfig object
        """
        # Start with empty config
        config_params = {}

        # Add core settings
        if self.config.speech_model == SpeechModel.NANO:
            config_params["speech_model"] = "nano"

        if self.config.language_code:
            config_params["language_code"] = self.config.language_code

        config_params["language_detection"] = self.config.language_detection
        config_params["dual_channel"] = self.config.dual_channel

        # Add text formatting options
        config_params["format_text"] = self.config.format_text
        config_params["punctuate"] = self.config.punctuate
        config_params["disfluencies"] = self.config.disfluencies
        config_params["filter_profanity"] = self.config.filter_profanity

        # Add speaker options
        config_params["speaker_labels"] = self.config.speaker_labels
        if self.config.speakers_expected is not None:
            config_params["speakers_expected"] = self.config.speakers_expected

        # Add audio intelligence options
        config_params["auto_chapters"] = self.config.auto_chapters
        config_params["auto_highlights"] = self.config.auto_highlights
        config_params["entity_detection"] = self.config.entity_detection
        config_params["iab_categories"] = self.config.iab_categories
        config_params["sentiment_analysis"] = self.config.sentiment_analysis
        config_params["summarization"] = self.config.summarization
        config_params["content_safety"] = self.config.content_safety

        # Add custom vocabulary options
        if self.config.word_boost:
            config_params["word_boost"] = self.config.word_boost

        # Add custom spelling
        if self.config.custom_spelling:
            config_params["custom_spelling"] = self.config.custom_spelling

        # Add webhook config
        if self.config.webhook_url:
            config_params["webhook_url"] = self.config.webhook_url
            if self.config.webhook_auth_header_name and \
                self.config.webhook_auth_header_value:
                config_params["webhook_auth_header_name"] = \
                    self.config.webhook_auth_header_name
                config_params["webhook_auth_header_value"] = \
                    self.config.webhook_auth_header_value

        # Override with any provided options
        if options:
            config_params |= options

        # Create config object
        return aai.TranscriptionConfig(**config_params)

    def _get_file_path(
        self, 
        audio_file: Union[Path, BinaryIO, str]
        ) -> Union[BinaryIO, str]:
        """
        Get appropriate file path for different input types.

        Args:
            audio_file: Path, file-like object, or URL of audio file

        Returns:
            Path or string for SDK

        Raises:
            TypeError: If input type is not supported
        """
        # Handle Path objects
        if isinstance(audio_file, Path):
            return str(audio_file)

        # Handle URLs
        if isinstance(audio_file, str) and (
            audio_file.startswith("http://") or 
            audio_file.startswith("https://")
        ):
            return audio_file

        # Handle file-like objects
        if hasattr(audio_file, "read"):
            # SDK handles file-like objects directly
            return audio_file

        raise TypeError(f"Unsupported audio file type: {type(audio_file)}")

    def _extract_words(self, transcript: aai.Transcript) -> TimedText:
        """
        Extract words with timestamps from transcript and return a TimedText object.

        Args:
            transcript: AssemblyAI transcript object

        Returns:
            TimedText object containing word-level units
        """
        if not transcript.words:
            raise ValueError(f"Transcript object has no words: {transcript}")

        units = [
            TimedTextUnit(
                index=None,
                granularity=Granularity.WORD,
                speaker=word.speaker,
                text=word.text,
                start_ms=word.start,
                end_ms=word.end,
                confidence=word.confidence,
            )
            for word in transcript.words
        ]

        # TimedText performs its own internal validation
        return TimedText(words=units, granularity=Granularity.WORD)

    def _extract_utterances(self, transcript: aai.Transcript) -> TimedText:
        """
        Extract utterances (speaker segments) from transcript and return a TimedText object.

        Args:
            transcript: AssemblyAI transcript object

        Returns:
            TimedText object containing utterance-level units
        """
        if not (utterances := getattr(transcript, "utterances", None)):
            # Return an empty TimedText if diarization wasn't requested
            return TimedText(segments=[], granularity=Granularity.SEGMENT)

        units = [
            TimedTextUnit(
                index=None,
                granularity=Granularity.SEGMENT,
                text=utterance.text,
                start_ms=utterance.start,
                end_ms=utterance.end,
                speaker=utterance.speaker,
                confidence=utterance.confidence,
            )
            for utterance in utterances
        ]

        return TimedText(segments=units, granularity=Granularity.SEGMENT)

    def _extract_audio_intelligence(self, transcript: aai.Transcript) -> Dict[str, Any]:
        """
        Extract audio intelligence features from transcript.

        Args:
            transcript: AssemblyAI transcript object

        Returns:
            Dictionary of audio intelligence features
        """
        intelligence = {}

        # Extract auto chapters
        if hasattr(transcript, "chapters") and transcript.chapters:
            chapters_data = []
            chapters_data.extend(
                {
                    "summary": chapter.summary,
                    "headline": chapter.headline,
                    "start_ms": chapter.start,
                    "end_ms": chapter.end,
                }
                for chapter in transcript.chapters
            )
            intelligence["chapters"] = chapters_data

        # Extract sentiment analysis
        if hasattr(transcript, "sentiment_analysis") and transcript.sentiment_analysis:
            sentiment_data = []
            sentiment_data.extend(
                {
                    "text": sentiment.text,
                    "sentiment": sentiment.sentiment,
                    "confidence": sentiment.confidence,
                    "start_ms": sentiment.start,
                    "end_ms": sentiment.end,
                }
                for sentiment in transcript.sentiment_analysis
            )
            intelligence["sentiment_analysis"] = sentiment_data

        # Extract entity detection
        if hasattr(transcript, "entities") and transcript.entities:
            entities_data = []
            entities_data.extend(
                {
                    "text": entity.text,
                    "entity_type": entity.entity_type,
                    "start_ms": entity.start,
                    "end_ms": entity.end,
                }
                for entity in transcript.entities
            )
            intelligence["entities"] = entities_data

        # Extract topics (IAB categories)
        if hasattr(transcript, "iab_categories") and transcript.iab_categories:
            topics_data = {
                "results": [],
                "summary": transcript.iab_categories.summary
            }

            if not transcript.iab_categories.results:
                return topics_data

            for result in transcript.iab_categories.results:
                topics_data["results"].append({
                    "text": result.text,
                    "labels": [
                        {"label": label.label, "relevance": label.relevance}
                        for label in result.labels
                    ],
                    "timestamp": {
                        "start": result.timestamp.start,
                        "end": result.timestamp.end
                    }
                })

            intelligence["topics"] = topics_data

        # Extract auto highlights
        if hasattr(transcript, "auto_highlights") and transcript.auto_highlights:
            intelligence["highlights"] = {
                "results": transcript.auto_highlights.results,
                "status": transcript.auto_highlights.status
            }

        return intelligence

    def standardize_result(self, transcript: aai.Transcript) -> TranscriptionResult:
        """
        Standardize AssemblyAI transcript to match common format.

        Args:
            transcript: AssemblyAI transcript object

        Returns:
            Standardized TranscriptionResult
        """
        # Extract words and utterances as TimedText
        words = self._extract_words(transcript)
        utterances = self._extract_utterances(transcript)

        language = self.config.language_code or \
                ("auto" if self.config.language_detection else "unknown")

        return TranscriptionResult(
            text=transcript.text or "",
            language=language,
            word_timing=words,
            utterance_timing=utterances,
            confidence=getattr(transcript, "confidence", 0.0),
            audio_duration_ms=getattr(transcript, "audio_duration", 0),
            transcript_id=transcript.id,
            status=transcript.status,
            raw_result=transcript.json_response,
        )

    def transcribe(
        self,
        audio_file: Union[Path, BinaryIO, str],
        options: Optional[Dict[str, Any]] = None
    ) -> TranscriptionResult:
        """
        Transcribe audio file to text using AssemblyAI's synchronous SDK approach.

        This method handles:
        - File paths
        - File-like objects
        - URLs

        Args:
            audio_file: Path, file-like object, or URL of audio file
            options: Provider-specific options for transcription

        Returns:
            TranscriptionResult with standardized transcription data
        """
        try:
            transcript = self._gen_transcript(options, audio_file)

            # Standardize the result format
            return self.standardize_result(transcript)

        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            raise RuntimeError(f"AssemblyAI transcription failed: {e}") from e

    def transcribe_async(
        self,
        audio_file: Union[Path, BinaryIO, str],
        options: Optional[Dict[str, Any]] = None
    ) -> Future:
        """
        Submit an asynchronous transcription job using AssemblyAI's SDK.

        This method submits a transcription job and returns immediately
        with a Future that resolves to the completed transcript.

        Args:
            audio_file: Path, file-like object, or URL of audio file
            options: Provider-specific options for transcription

        Returns:
            Future object that resolves to the completed aai.Transcript

        Notes:
            Call .result() on the returned Future to block until the
            transcript is ready, or retrieve it later via get_result
            using the transcript ID.
        """
        try:
            # Create configuration with options
            tx_config = self._create_transcription_config(options)

            # Get file path/object in the right format
            file_path = self._get_file_path(audio_file)

            logger.info("Submitting asynchronous transcription with AssemblyAI SDK")

            # Use the SDK's asynchronous submit method
            # This returns a Future object containing a Transcript
            return self.transcriber.transcribe_async(file_path, config=tx_config)

        except Exception as e:
            logger.error(f"Transcription submission failed: {e}")
            raise RuntimeError(f"AssemblyAI transcription submission failed: {e}") \
                from e

    def get_result(self, job_id: str) -> TranscriptionResult:
        """
        Get results for an existing transcription job.

        This method blocks until the transcript is retrieved.

        Args:
            job_id: ID of the transcription job

        Returns:
            TranscriptionResult with standardized transcription data
        """
        try:
            # Use the SDK's get_by_id method to retrieve the transcript
            # This blocks until the transcript is retrieved
            transcript = aai.Transcript.get_by_id(job_id)

            # Standardize the result format
            return self.standardize_result(transcript)

        except Exception as e:
            logger.error(f"Failed to retrieve transcript {job_id}: {e}")
            raise RuntimeError(f"Failed to retrieve transcript: {e}") from e

    def get_subtitles(
        self, 
        transcript_id: str, 
        format_type: str = "srt",
    ) -> str:
        """
        Get subtitles directly from AssemblyAI.

        Args:
            transcript_id: ID of the transcription job
            format_type: Format type ("srt" or "vtt")

        Returns:
            String representation in the requested format

        Raises:
            ValueError: If the format type is not supported
        """
        chars_per_caption = self.config.chars_per_caption

        format_type = format_type.lower()

        if format_type not in ["srt", "vtt"]:
            raise ValueError(
                f"Unsupported subtitle format: {format_type}. "
                "Supported formats: srt, vtt"
                )
        # Create transcript object from ID
        transcript = aai.Transcript(transcript_id=transcript_id)

        # Get subtitles in requested format
        if format_type == "srt":
            return transcript.export_subtitles_srt(chars_per_caption=chars_per_caption)
        else:  # format_type == "vtt"
            return transcript.export_subtitles_vtt(chars_per_caption=chars_per_caption)

    def transcribe_to_format(
        self,
        audio_file: Union[Path, BinaryIO, str],
        format_type: str = "srt",
        transcription_options: Optional[Dict[str, Any]] = None,
        format_options: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Transcribe audio and return result in specified format.

        Takes advantage of the direct subtitle generation
        functionality when requesting SRT or VTT formats.

        Args:
            audio_file: Path, file-like object, or URL of audio file
            format_type: Format type (e.g., "srt", "vtt", "text")
            transcription_options: Options for transcription
            format_options: Format-specific options

        Returns:
            String representation in the requested format
        """
        format_type = format_type.lower()
        chars_per_caption = format_options.get(
            'chars_per_caption', self.config.chars_per_caption) \
            if format_options else self.config.chars_per_caption

        transcript = self._gen_transcript(
                transcription_options, audio_file
            )

        # Check if we need direct subtitle generation
        if format_type == "srt":  
            return transcript.export_subtitles_srt(chars_per_caption=chars_per_caption)
        elif format_type == "vtt":
            return transcript.export_subtitles_vtt(chars_per_caption=chars_per_caption)

        # For other formats, use the format converter
        # First get a normal transcription result
        result = self.transcribe(audio_file, transcription_options)

        # Then convert to the requested format
        return self.format_converter.convert(
            result, format_type, format_options or {}
        )

    def _gen_transcript(self, transcription_options, audio_file):
        # Create configuration with options
        tx_config = self._create_transcription_config(transcription_options)

        # Get file path/object in the right format
        file_path = self._get_file_path(audio_file)

        logger.info("Starting synchronous transcription with AssemblyAI SDK")

        # Use the SDK's synchronous transcribe method
        # This will block until transcription is complete
        return self.transcriber.transcribe(file_path, config=tx_config)
config = AAIConfig() instance-attribute
format_converter = FormatConverter() instance-attribute
transcriber = aai.Transcriber(config=(self._create_transcription_config(options))) instance-attribute
__init__(api_key=None, options=None)

Initialize the AssemblyAI transcription service.

Parameters:

- api_key (Optional[str]): AssemblyAI API key (defaults to ASSEMBLYAI_API_KEY env var). Default: None
- options (Optional[Dict[str, Any]]): Provider-specific transcription options. Default: None
Source code in src/tnh_scholar/audio_processing/transcription/assemblyai_service.py
def __init__(
    self, 
    api_key: Optional[str] = None, 
    options: Optional[Dict[str, Any]] = None,
    ):
    """
    Initialize the AssemblyAI transcription service.

    Args:
        api_key: AssemblyAI API key (defaults to ASSEMBLYAI_API_KEY env var)
        options: Comprehensive configuration options
    """
    # Initialize format converter for fallback cases
    self.format_converter = FormatConverter()

    # Set and validate configuration
    self.config = AAIConfig()

    # Configure SDK
    self._configure_sdk(api_key)

    # Create transcriber instance
    self.transcriber = aai.Transcriber(
        config=self._create_transcription_config(options)
        )

    logger.debug("Initialized AssemblyAI service with SDK")
get_result(job_id)

Get results for an existing transcription job.

This method blocks until the transcript is retrieved.

Parameters:

Name Type Description Default
job_id str

ID of the transcription job

required

Returns:

Type Description
TranscriptionResult

Standardized transcription result

Source code in src/tnh_scholar/audio_processing/transcription/assemblyai_service.py
def get_result(self, job_id: str) -> TranscriptionResult:
    """
    Get results for an existing transcription job.

    This method blocks until the transcript is retrieved.

    Args:
        job_id: ID of the transcription job

    Returns:
        Standardized transcription result
    """
    try:
        # Use the SDK's get_by_id method to retrieve the transcript
        # This blocks until the transcript is retrieved
        transcript = aai.Transcript.get_by_id(job_id)

        # Standardize the result format
        return self.standardize_result(transcript)

    except Exception as e:
        logger.error(f"Failed to retrieve transcript {job_id}: {e}")
        raise RuntimeError(f"Failed to retrieve transcript: {e}") from e
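`get_result` wraps any SDK failure in a `RuntimeError` while preserving the original exception via `raise ... from e`. A minimal, self-contained sketch of this chaining pattern (the `fetch_transcript` helper and `flaky_getter` stand-in are illustrative, not part of the API):

```python
def fetch_transcript(job_id, getter):
    """Illustrative helper mirroring get_result's error handling."""
    try:
        return getter(job_id)
    except Exception as e:
        raise RuntimeError(f"Failed to retrieve transcript {job_id}: {e}") from e

def flaky_getter(job_id):
    # Stand-in for aai.Transcript.get_by_id failing on a network error.
    raise ConnectionError("network down")

caught = None
try:
    fetch_transcript("abc123", flaky_getter)
except RuntimeError as err:
    caught = err  # __cause__ still holds the original ConnectionError
```

Because of the `from e` clause, callers can inspect `caught.__cause__` to see the underlying SDK error.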
get_subtitles(transcript_id, format_type='srt')

Get subtitles directly from AssemblyAI. The maximum characters per caption is taken from the service configuration.

Parameters:

Name Type Description Default
transcript_id str

ID of the transcription job

required
format_type str

Format type ("srt" or "vtt")

'srt'

Returns:

Type Description
str

String representation in the requested format

Raises:

Type Description
ValueError

If the format type is not supported

Source code in src/tnh_scholar/audio_processing/transcription/assemblyai_service.py
def get_subtitles(
    self, 
    transcript_id: str, 
    format_type: str = "srt",
) -> str:
    """
    Get subtitles directly from AssemblyAI.

    Args:
        transcript_id: ID of the transcription job
        format_type: Format type ("srt" or "vtt")

    Returns:
        String representation in the requested format

    Raises:
        ValueError: If the format type is not supported
    """
    chars_per_caption = self.config.chars_per_caption

    format_type = format_type.lower()

    if format_type not in ["srt", "vtt"]:
        raise ValueError(
            f"Unsupported subtitle format: {format_type}. "
            "Supported formats: srt, vtt"
            )
    # Create transcript object from ID
    transcript = aai.Transcript(transcript_id=transcript_id)

    # Get subtitles in requested format
    if format_type == "srt":
        return transcript.export_subtitles_srt(chars_per_caption=chars_per_caption)
    else:  # format_type == "vtt"
        return transcript.export_subtitles_vtt(chars_per_caption=chars_per_caption)
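The method lower-cases and validates `format_type` before dispatching to the matching exporter. The dispatch shape, isolated as a self-contained sketch (the exporter callables are stand-ins for the SDK's export methods):

```python
def export_subtitles(format_type, exporters):
    """Validate a requested subtitle format and dispatch to its exporter."""
    fmt = format_type.lower()
    if fmt not in exporters:
        raise ValueError(
            f"Unsupported subtitle format: {fmt}. "
            f"Supported formats: {', '.join(sorted(exporters))}"
        )
    return exporters[fmt]()

exporters = {
    # Stand-ins for transcript.export_subtitles_srt / export_subtitles_vtt
    "srt": lambda: "1\n00:00:00,000 --> 00:00:01,000\nhello\n",
    "vtt": lambda: "WEBVTT\n\n00:00:00.000 --> 00:00:01.000\nhello\n",
}
srt_out = export_subtitles("SRT", exporters)  # case-insensitive
```

Note the one structural difference between the two formats: SRT separates milliseconds with a comma, WebVTT with a dot.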
standardize_result(transcript)

Standardize AssemblyAI transcript to match common format.

Parameters:

Name Type Description Default
transcript Transcript

AssemblyAI transcript object

required

Returns:

Type Description
TranscriptionResult

Standardized result dictionary

Source code in src/tnh_scholar/audio_processing/transcription/assemblyai_service.py
def standardize_result(self, transcript: aai.Transcript) -> TranscriptionResult:
    """
    Standardize AssemblyAI transcript to match common format.

    Args:
        transcript: AssemblyAI transcript object

    Returns:
        Standardized result dictionary
    """
    # Extract words and utterances as TimedText
    words = self._extract_words(transcript)
    utterances = self._extract_utterances(transcript)

    language = self.config.language_code or \
            ("auto" if self.config.language_detection else "unknown")

    return TranscriptionResult(
        text=transcript.text or "",
        language=language,
        word_timing=words,
        utterance_timing=utterances,
        confidence=getattr(transcript, "confidence", 0.0),
        audio_duration_ms=getattr(transcript, "audio_duration", 0),
        transcript_id=transcript.id,
        status=transcript.status,
        raw_result=transcript.json_response,
    )
transcribe(audio_file, options=None)

Transcribe audio file to text using AssemblyAI's synchronous SDK approach.

This method handles: - File paths - File-like objects - URLs

Parameters:

Name Type Description Default
audio_file Union[Path, BinaryIO, str]

Path, file-like object, or URL of audio file

required
options Optional[Dict[str, Any]]

Provider-specific options for transcription

None

Returns:

Type Description
TranscriptionResult

Standardized transcription result

Source code in src/tnh_scholar/audio_processing/transcription/assemblyai_service.py
def transcribe(
    self,
    audio_file: Union[Path, BinaryIO, str],
    options: Optional[Dict[str, Any]] = None
) -> TranscriptionResult:
    """
    Transcribe audio file to text using AssemblyAI's synchronous SDK approach.

    This method handles:
    - File paths
    - File-like objects
    - URLs

    Args:
        audio_file: Path, file-like object, or URL of audio file
        options: Provider-specific options for transcription

    Returns:
        Standardized transcription result
    """
    try:
        transcript = self._gen_transcript(options, audio_file)

        # Standardize the result format
        return self.standardize_result(transcript)

    except Exception as e:
        logger.error(f"Transcription failed: {e}")
        raise RuntimeError(f"AssemblyAI transcription failed: {e}") from e
transcribe_async(audio_file, options=None)

Submit an asynchronous transcription job using AssemblyAI's SDK.

This method submits a transcription job and returns immediately with a Future that can be used to retrieve the transcript later.

Parameters:

Name Type Description Default
audio_file Union[Path, BinaryIO, str]

Path, file-like object, or URL of audio file

required
options Optional[Dict[str, Any]]

Provider-specific options for transcription

None

Returns:

Type Description
Future

Future that resolves to the completed AssemblyAI Transcript

Notes

The SDK's transcribe_async method returns a Future object; call result() on it to block until the transcript is available.

Source code in src/tnh_scholar/audio_processing/transcription/assemblyai_service.py
def transcribe_async(
    self,
    audio_file: Union[Path, BinaryIO, str],
    options: Optional[Dict[str, Any]] = None
) ->  Future:
    """
    Submit an asynchronous transcription job using AssemblyAI's SDK.

    This method submits a transcription job and returns immediately with
    a Future that can be used to retrieve the transcript later.

    Args:
        audio_file: Path, file-like object, or URL of audio file
        options: Provider-specific options for transcription

    Returns:
        Future that resolves to the completed AssemblyAI Transcript

    Notes:
        The SDK's transcribe_async method returns a Future object;
        call result() on it to block until transcription completes.
    """
    try:
        # Create configuration with options
        tx_config = self._create_transcription_config(options)

        # Get file path/object in the right format
        file_path = self._get_file_path(audio_file)

        logger.info("Submitting asynchronous transcription with AssemblyAI SDK")

        # Use the SDK's asynchronous submit method
        # This returns a Future object containing a Transcript
        return self.transcriber.transcribe_async(file_path, config=tx_config)

    except Exception as e:
        logger.error(f"Transcription submission failed: {e}")
        raise RuntimeError(f"AssemblyAI transcription submission failed: {e}") \
            from e
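Since `transcribe_async` returns a `concurrent.futures.Future` rather than a finished transcript, callers consume it with `result()`. A sketch of that pattern, using a `ThreadPoolExecutor` as a stand-in for the SDK's background worker (the job payload is illustrative):

```python
from concurrent.futures import ThreadPoolExecutor

def fake_transcription_job():
    # Stand-in for the SDK's background transcription work.
    return {"id": "abc123", "text": "hello world"}

with ThreadPoolExecutor(max_workers=1) as pool:
    future = pool.submit(fake_transcription_job)  # analogous to transcribe_async(...)
    transcript = future.result()                  # blocks until the job completes
```

`future.result()` re-raises any exception the job raised, so the same try/except handling used for synchronous transcription applies here too.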
transcribe_to_format(audio_file, format_type='srt', transcription_options=None, format_options=None)

Transcribe audio and return result in specified format.

Takes advantage of the direct subtitle generation functionality when requesting SRT or VTT formats.

Parameters:

Name Type Description Default
audio_file Union[Path, BinaryIO, str]

Path, file-like object, or URL of audio file

required
format_type str

Format type (e.g., "srt", "vtt", "text")

'srt'
transcription_options Optional[Dict[str, Any]]

Options for transcription

None
format_options Optional[Dict[str, Any]]

Format-specific options

None

Returns:

Type Description
str

String representation in the requested format

Source code in src/tnh_scholar/audio_processing/transcription/assemblyai_service.py
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
def transcribe_to_format(
    self,
    audio_file: Union[Path, BinaryIO, str],
    format_type: str = "srt",
    transcription_options: Optional[Dict[str, Any]] = None,
    format_options: Optional[Dict[str, Any]] = None
) -> str:
    """
    Transcribe audio and return result in specified format.

    Takes advantage of the direct subtitle generation
    functionality when requesting SRT or VTT formats.

    Args:
        audio_file: Path, file-like object, or URL of audio file
        format_type: Format type (e.g., "srt", "vtt", "text")
        transcription_options: Options for transcription
        format_options: Format-specific options

    Returns:
        String representation in the requested format
    """
    format_type = format_type.lower()
    chars_per_caption = format_options.get(
        'chars_per_caption', self.config.chars_per_caption) \
        if format_options else self.config.chars_per_caption

    transcript = self._gen_transcript(
            transcription_options, audio_file
        )

    # Check if we need direct subtitle generation
    if format_type == "srt":  
        return transcript.export_subtitles_srt(chars_per_caption=chars_per_caption)
    elif format_type == "vtt":
        return transcript.export_subtitles_vtt(chars_per_caption=chars_per_caption)

    # For other formats, use the format converter
    # First get a normal transcription result
    result = self.transcribe(audio_file, transcription_options)

    # Then convert to the requested format
    return self.format_converter.convert(
        result, format_type, format_options or {}
    )
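The `chars_per_caption` resolution at the top of `transcribe_to_format` reads the value from `format_options` when present and falls back to the service configuration otherwise. The same logic, isolated (the default of 42 mirrors the config shown in this module, but is an assumption here):

```python
def resolve_chars_per_caption(format_options, config_default=42):
    """Prefer an explicit format option; otherwise use the configured default."""
    if format_options:
        return format_options.get("chars_per_caption", config_default)
    return config_default
```

An empty or `None` options dict yields the configured default; an explicit value always wins.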
SpeechModel

Bases: str, Enum

Supported AssemblyAI speech models.

Source code in src/tnh_scholar/audio_processing/transcription/assemblyai_service.py
class SpeechModel(str, Enum):
    """Supported AssemblyAI speech models."""
    BEST = "best"
    NANO = "nano"
BEST = 'best' class-attribute instance-attribute
NANO = 'nano' class-attribute instance-attribute
format_converter
tnh_scholar.audio_processing.transcription.format_converter

Thin facade that turns raw transcription-service output dictionaries into the formats requested by callers (plain text and SRT; VTT coming later).

Core heavy lifting now lives in:

  • TimedText / TimedTextUnit - canonical internal representation
  • SegmentBuilder - word-level -> sentence/segment chunking
  • SRTProcessor - rendering to .srt

Only one public method remains: :py:meth:FormatConverter.convert.

logger = get_child_logger(__name__) module-attribute
FormatConverter

Convert a raw transcription result to text, SRT, or (placeholder) VTT.

The raw result must follow the loose schema:

  • {"utterances": [...]} -> already speaker-segmented
  • {"words": [...]} -> word-level; we chunk via SegmentBuilder
  • {"text": "...", "audio_duration_ms": 12345} -> single blob fallback

Source code in src/tnh_scholar/audio_processing/transcription/format_converter.py
class FormatConverter:
    """
    Convert a raw transcription result to *text*, *SRT*, or (placeholder) *VTT*.

    The *raw* result must follow the loose schema
    - ``{"utterances": [...]}`` -> already speaker-segmented
    - ``{"words":       [...]}`` -> word-level; we chunk via :class:`SegmentBuilder`
    - ``{"text": "...", "audio_duration_ms": 12345}`` -> single blob fallback
    """

    def __init__(self, config: Optional[FormatConverterConfig] = None):
        self.config = config or FormatConverterConfig()
        self._segment_builder = TextSegmentBuilder(
            max_duration_ms=self.config.max_entry_duration_ms,
            target_characters=self.config.characters_per_entry,
            avoid_orphans=True,
            ignore_speaker=not self.config.include_speaker,  
            max_gap_duration_ms=self.config.max_gap_duration_ms
        )

    def convert(
        self,
        result: TranscriptionResult,
        format_type: str = "srt",
        format_options: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Convert *result* to the given *format_type*.

        Parameters
        ----------
        result : dict
            Raw transcription output.
        format_type : {"srt", "text", "vtt"}
        format_options : dict | None
            Currently only ``{"include_speaker": bool}`` recognized for *srt*.
        """
        format_type = format_type.lower()
        format_options = format_options or {}

        timed_text = self._build_timed_text(result)

        if format_type == "text":
            return self._to_plain_text(timed_text)

        if format_type == "srt":
            include_speaker = format_options.get("include_speaker", True)
            processor = SRTProcessor()
            return processor.generate(timed_text, include_speaker=include_speaker)

        if format_type == "vtt":
            raise NotImplementedError("VTT conversion not implemented yet.")

        raise ValueError(f"Unsupported format_type: {format_type}")

    def _to_plain_text(self, timed_text: TimedText) -> str:
        """Flatten ``TimedText`` into a newline-separated block of text."""
        return "\n".join(unit.text for unit in timed_text.segments if unit.text)

    def _build_timed_text(self, result: TranscriptionResult) -> TimedText:
        """
        Normalize *result* into :class:`TimedText`, handling three cases:

        1. *utterance*-level input (already segmented)
        2. *word*-level input  - chunk via :class:`SegmentBuilder`
        3. plain *text* fallback
        """

        if timed_text := result.utterance_timing:
            units: List[TimedTextUnit] = []
            for i, unit in enumerate(timed_text.iter_segments(), start=1):
                data = unit.model_copy()

                units.append(
                    TimedTextUnit(
                        granularity=Granularity.SEGMENT,
                        text=data.text,
                        start_ms=data.start_ms,
                        end_ms=data.end_ms,
                        speaker=data.speaker,
                        index=i,
                        confidence=data.confidence,
                    )
                )

            return TimedText(segments=units, granularity=Granularity.SEGMENT)

        if words := result.word_timing:
            # *SegmentBuilder* returns a list[TimedTextUnit]
            return self._segment_builder.create_segments(words)

        if text := result.text:
            duration_ms = result.audio_duration_ms

            units = [
                TimedTextUnit(
                    granularity=Granularity.SEGMENT,
                    text=text,
                    start_ms=0,
                    end_ms=duration_ms or 0,
                    speaker=None,
                    index=None,
                    confidence=None,
                )
            ]
            return TimedText(segments=units, granularity=Granularity.SEGMENT)

        # If we arrived here – nothing to work with.
        raise ValueError(
            "Cannot build TimedText: result contains no utterances, words, or text."
        )
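The word-level branch delegates to `SegmentBuilder`, which groups word timings into caption-sized segments. A simplified, self-contained sketch of that chunking idea, using plain `(text, start_ms, end_ms)` tuples and only the target-character rule (the real builder also honors duration and gap limits and avoids orphans):

```python
def chunk_words(words, target_characters=42):
    """Group (text, start_ms, end_ms) word tuples into caption-sized segments."""
    segments, current, length = [], [], 0
    for text, start, end in words:
        # Flush the current segment if adding this word would exceed the target.
        if current and length + 1 + len(text) > target_characters:
            segments.append(
                (" ".join(w[0] for w in current), current[0][1], current[-1][2])
            )
            current, length = [], 0
        current.append((text, start, end))
        length += len(text) + (1 if length else 0)  # +1 for the joining space
    if current:
        segments.append(
            (" ".join(w[0] for w in current), current[0][1], current[-1][2])
        )
    return segments
```

Each emitted segment keeps the start time of its first word and the end time of its last, which is what makes the later SRT rendering possible.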
config = config or FormatConverterConfig() instance-attribute
__init__(config=None)
Source code in src/tnh_scholar/audio_processing/transcription/format_converter.py
def __init__(self, config: Optional[FormatConverterConfig] = None):
    self.config = config or FormatConverterConfig()
    self._segment_builder = TextSegmentBuilder(
        max_duration_ms=self.config.max_entry_duration_ms,
        target_characters=self.config.characters_per_entry,
        avoid_orphans=True,
        ignore_speaker=not self.config.include_speaker,  
        max_gap_duration_ms=self.config.max_gap_duration_ms
    )
convert(result, format_type='srt', format_options=None)

Convert result to the given format_type.

Parameters

result : dict
    Raw transcription output.
format_type : {"srt", "text", "vtt"}
format_options : dict | None
    Currently only {"include_speaker": bool} recognized for srt.

Source code in src/tnh_scholar/audio_processing/transcription/format_converter.py
def convert(
    self,
    result: TranscriptionResult,
    format_type: str = "srt",
    format_options: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Convert *result* to the given *format_type*.

    Parameters
    ----------
    result : dict
        Raw transcription output.
    format_type : {"srt", "text", "vtt"}
    format_options : dict | None
        Currently only ``{"include_speaker": bool}`` recognized for *srt*.
    """
    format_type = format_type.lower()
    format_options = format_options or {}

    timed_text = self._build_timed_text(result)

    if format_type == "text":
        return self._to_plain_text(timed_text)

    if format_type == "srt":
        include_speaker = format_options.get("include_speaker", True)
        processor = SRTProcessor()
        return processor.generate(timed_text, include_speaker=include_speaker)

    if format_type == "vtt":
        raise NotImplementedError("VTT conversion not implemented yet.")

    raise ValueError(f"Unsupported format_type: {format_type}")
FormatConverterConfig

Bases: BaseModel

User-tunable knobs for :class:FormatConverter.

Only a handful remain now that the heavy logic moved to SegmentBuilder.

Source code in src/tnh_scholar/audio_processing/transcription/format_converter.py
class FormatConverterConfig(BaseModel):
    """
    User-tunable knobs for :class:`FormatConverter`.

    Only a handful remain now that the heavy logic moved to `SegmentBuilder`.
    """

    max_entry_duration_ms: int = 6_000
    include_segment_index: bool = True
    include_speaker: bool = True
    characters_per_entry: int = 42
    max_gap_duration_ms: int = 2_000
characters_per_entry = 42 class-attribute instance-attribute
include_segment_index = True class-attribute instance-attribute
include_speaker = True class-attribute instance-attribute
max_entry_duration_ms = 6000 class-attribute instance-attribute
max_gap_duration_ms = 2000 class-attribute instance-attribute
patches
patch_file_with_name(file_obj, extension)

Ensures the file-like object has a .name attribute with the correct extension.

Source code in src/tnh_scholar/audio_processing/transcription/patches.py
def patch_file_with_name(file_obj: BytesIO, extension: str) -> BinaryIO:
    """
    Ensures the file-like object has a .name attribute with the correct extension.
    """
    file_obj.name = f"filename_placeholder.{extension}"
    return file_obj
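This patch works because `BytesIO` objects accept arbitrary attributes, so an in-memory buffer can be given the synthetic filename that extension-sniffing APIs (such as OpenAI's audio endpoint) expect. A standalone, stdlib-only reproduction:

```python
from io import BytesIO

def patch_file_with_name(file_obj, extension):
    # Give the in-memory file a synthetic name so the extension can be sniffed.
    file_obj.name = f"filename_placeholder.{extension}"
    return file_obj

buf = patch_file_with_name(BytesIO(b"RIFF...fake audio bytes"), "wav")
```

The buffer's contents are untouched; only the `.name` attribute is added.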
patch_whisper_options(options, file_extension)

Patch routine to ensure 'file_extension' is present in transcription options dict. This is a workaround for OpenAI Whisper API, which requires file-like objects to have a filename/extension. Only allows known audio extensions.

Parameters:

Name Type Description Default
options Optional[Dict[str, Any]]

Transcription options dictionary (will not be mutated)

required
file_extension str

File extension string (with or without leading dot)

required

Returns:

Type Description
Dict[str, Any]

New options dictionary with 'file_extension' set appropriately

Raises:

Type Description
ValueError

If file_extension is not in the allowed list

Source code in src/tnh_scholar/audio_processing/transcription/patches.py
def patch_whisper_options(
    options: Optional[Dict[str, Any]],
    file_extension: str
    ) -> Dict[str, Any]:
    """
    Patch routine to ensure 'file_extension' is present in transcription options dict.
    This is a workaround for OpenAI Whisper API, which requires file-like objects to have a
    filename/extension. Only allows known audio extensions.

    Args:
        options: Transcription options dictionary (will not be mutated)
        file_extension: File extension string (with or without leading dot)

    Returns:
        New options dictionary with 'file_extension' set appropriately

    Raises:
        ValueError: If file_extension is not in the allowed list
    """
    patched = dict(options) if options is not None else {}
    ext = file_extension.lstrip('.')
    if ext.lower() not in _ALLOWED_EXTENSIONS:
        raise ValueError(
            f"Unsupported file extension '{ext}'. Allowed extensions: {_ALLOWED_EXTENSIONS}"
        )
    patched['file_extension'] = ext
    return patched
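The validation above never mutates the caller's dict and strips any leading dot before checking the allow-list. A self-contained sketch of the same shape, with an illustrative allow-list (the real `_ALLOWED_EXTENSIONS` lives in `patches.py`):

```python
ALLOWED = {"mp3", "mp4", "wav", "m4a", "flac", "ogg", "webm"}  # illustrative

def patch_options(options, file_extension):
    """Return a copy of options with a validated 'file_extension' entry."""
    patched = dict(options) if options is not None else {}  # copy, never mutate
    ext = file_extension.lstrip(".")
    if ext.lower() not in ALLOWED:
        raise ValueError(f"Unsupported file extension '{ext}'")
    patched["file_extension"] = ext
    return patched

original = {"language": "vi"}
patched = patch_options(original, ".mp3")
```

Copying first means a caller's options dict can be safely reused across multiple files with different extensions.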
srt_processor
SRTConfig

Configuration options for SRT processing.

Source code in src/tnh_scholar/audio_processing/transcription/srt_processor.py
class SRTConfig:
    """Configuration options for SRT processing."""

    def __init__(
        self,
        include_speaker=False,
        speaker_format="[{speaker}] {text}",
        reindex_entries=True,
        timestamp_format="{:02d}:{:02d}:{:02d},{:03d}",
        max_chars_per_line=42,
        use_pysrt=False,
    ):
        """
        Initialize with default settings.

        Args:
            include_speaker: Whether to include speaker labels in output
            speaker_format: Format string for speaker attribution
            reindex_entries: Whether to reindex entries sequentially
            timestamp_format: Format string for timestamp formatting
            max_chars_per_line: Maximum characters per line before splitting
            use_pysrt: Whether to use the pysrt backend for parsing and generation
        """
        self.include_speaker = include_speaker
        self.speaker_format = speaker_format
        self.reindex_entries = reindex_entries
        self.timestamp_format = timestamp_format
        self.max_chars_per_line = max_chars_per_line
        self.use_pysrt = use_pysrt
include_speaker = include_speaker instance-attribute
max_chars_per_line = max_chars_per_line instance-attribute
reindex_entries = reindex_entries instance-attribute
speaker_format = speaker_format instance-attribute
timestamp_format = timestamp_format instance-attribute
use_pysrt = use_pysrt instance-attribute
__init__(include_speaker=False, speaker_format='[{speaker}] {text}', reindex_entries=True, timestamp_format='{:02d}:{:02d}:{:02d},{:03d}', max_chars_per_line=42, use_pysrt=False)

Initialize with default settings.

Parameters:

Name Type Description Default
include_speaker

Whether to include speaker labels in output

False
speaker_format

Format string for speaker attribution

'[{speaker}] {text}'
reindex_entries

Whether to reindex entries sequentially

True
timestamp_format

Format string for timestamp formatting

'{:02d}:{:02d}:{:02d},{:03d}'
max_chars_per_line

Maximum characters per line before splitting

42
use_pysrt

Whether to use the pysrt backend for parsing and generation

False
Source code in src/tnh_scholar/audio_processing/transcription/srt_processor.py
def __init__(
    self,
    include_speaker=False,
    speaker_format="[{speaker}] {text}",
    reindex_entries=True,
    timestamp_format="{:02d}:{:02d}:{:02d},{:03d}",
    max_chars_per_line=42,
    use_pysrt=False,
):
    """
    Initialize with default settings.

    Args:
        include_speaker: Whether to include speaker labels in output
        speaker_format: Format string for speaker attribution
        reindex_entries: Whether to reindex entries sequentially
        timestamp_format: Format string for timestamp formatting
        max_chars_per_line: Maximum characters per line before splitting
        use_pysrt: Whether to use the pysrt backend for parsing and generation
    """
    self.include_speaker = include_speaker
    self.speaker_format = speaker_format
    self.reindex_entries = reindex_entries
    self.timestamp_format = timestamp_format
    self.max_chars_per_line = max_chars_per_line
    self.use_pysrt = use_pysrt
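The default `timestamp_format` string produces standard SRT timestamps when fed hours, minutes, seconds, and milliseconds. A sketch of how a millisecond offset maps onto it:

```python
TIMESTAMP_FORMAT = "{:02d}:{:02d}:{:02d},{:03d}"  # SRTConfig default

def ms_to_srt_timestamp(ms: int) -> str:
    """Convert a millisecond offset to an SRT-style HH:MM:SS,mmm timestamp."""
    hours, rem = divmod(ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    seconds, millis = divmod(rem, 1_000)
    return TIMESTAMP_FORMAT.format(hours, minutes, seconds, millis)
```

Swapping the comma for a dot in the format string would yield WebVTT-style timestamps instead.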
SRTProcessor

Handles parsing and generating SRT format.

Provides functionality to convert between SRT text format and TimedText objects, with various formatting options. Supports both native parsing/generation and pysrt backend.

Source code in src/tnh_scholar/audio_processing/transcription/srt_processor.py
class SRTProcessor:
    """
    Handles parsing and generating SRT format.

    Provides functionality to convert between SRT text format and
    TimedText objects, with various formatting options.
    Supports both native parsing/generation and pysrt backend.
    """

    def __init__(self, config: Optional[SRTConfig] = None):
        """
        Initialize with optional configuration overrides.

        Args:
            config: Configuration options for SRT processing
        """
        self.config = config or SRTConfig()

    def merge_srts(self, srt_list: List[str]) -> str:
        """Merge multiple SRT files into a single SRT string."""
        timed_text_list = [self.parse(srt) for srt in srt_list]
        combined_timed_text = self.combine(timed_text_list)
        return self.generate(combined_timed_text, self.config.include_speaker)

    def generate(
        self, 
        timed_text: TimedText, 
        include_speaker: Optional[bool] = None
    ) -> str:
        """
        Generate SRT content from a TimedText object.
        Uses internal generator or pysrt depending on configuration.
        """
        if include_speaker is None:
            include_speaker = self.config.include_speaker
        if self.config.use_pysrt:
            return self._generate_with_pysrt(timed_text)

        srt_parts = []
        srt_parts.extend(
            self._generate_entry(
                entry, 
                index=i if self.config.reindex_entries else entry.index,
                include_speaker=include_speaker
            )
            for i, entry in enumerate(timed_text.iter_segments(), start=1)
        )
        return "\n".join(srt_parts)

    def parse(self, srt_content: str) -> TimedText:
        """
        Parse SRT content into a new TimedText object.
        Uses internal parser or pysrt depending on configuration.
        """
        if self.config.use_pysrt:
            return self._parse_with_pysrt(srt_content)
        parser = self._SRTParser(srt_content)
        return parser.parse()

    def shift_timestamps(self, timed_text: TimedText, offset_ms: int) -> TimedText:
        """
        Shift all timestamps by the given offset.

        Args:
            timed_text: TimedText object to shift
            offset_ms: Offset in milliseconds to apply

        Returns:
            New TimedText object with adjusted timestamps
        """
        new_segments = [
            segment.shift_time(offset_ms)
            for segment in timed_text.iter_segments()
        ]
        return TimedText(segments=new_segments)

    def combine(self, timed_texts: List[TimedText]) -> TimedText:
        """
        Combine multiple TimedText objects into one, sorted by start time.

        Args:
            timed_texts: TimedText objects to combine

        Returns:
            Combined TimedText object
        """
        combined_segments = []
        for timed_text in timed_texts:
            combined_segments.extend(timed_text.segments)

        # Sort by start time
        combined_segments.sort(key=lambda x: x.start_ms)

        return TimedText(segments=combined_segments)

    def assign_single_speaker(self, srt_content: str, speaker: str) -> str:
        """
        Assign the same speaker to all segments in the SRT content.
        """
        timed_text = self.parse(srt_content)
        timed_text.set_all_speakers(speaker)
        return self.generate(timed_text, include_speaker=True)

    def assign_speaker_by_mapping(
        self, srt_content: str, speaker_labels: dict[str, list[int]]
        ) -> str:
        """
        Assign speakers to segments based on a mapping of speaker to segment indices.
        (Not implemented yet.)
        """
        raise NotImplementedError("assign_speaker_by_mapping is not implemented yet.")

    def add_speaker_labels(
        self, 
        srt_content: str, 
        *, 
        speaker: Optional[str] = None, 
        speaker_labels: Optional[dict[str, list[int]]] = None
        ) -> str:
        """
        Unified entry point for adding speaker labels. 
        (Not implemented yet.)
        """
        raise NotImplementedError("add_speaker_labels is not implemented yet.")

    class _SRTParser:
        """Inner class to manage the state of the SRT parsing."""

        def __init__(self, srt_content: str):
            self.lines = srt_content.splitlines()
            self.current_index = 0
            self.timed_segments = []

        def parse(self) -> TimedText:
            while self.current_index < len(self.lines):
                if self.lines[self.current_index].strip():
                    try:
                        timed_segment = self._parse_entry()
                        self.timed_segments.append(timed_segment)
                    except (IndexError, ValueError) as e:
                        raise ValueError(
                            f"Invalid SRT format at line {self.current_index + 1}: {e}"
                        ) from e
                self.current_index += 1  # Always increment to avoid infinite loops

            return TimedText(segments=self.timed_segments)

        def _parse_entry(self) -> TimedTextUnit:
            index = self._parse_index()
            start_time, end_time = self._parse_timestamps()
            text = self._parse_text()
            start_ms = self._timestamp_to_ms(start_time)
            end_ms = self._timestamp_to_ms(end_time)
            speaker, text = _extract_speaker_from_text(text)

            return TimedTextUnit(
                text=text,
                start_ms=start_ms,
                end_ms=end_ms,
                speaker=speaker,
                index=index,
                granularity=Granularity.SEGMENT,
                confidence=None,
            )

        def _parse_index(self) -> int:
            try:
                index = int(self.lines[self.current_index])
                self.current_index += 1
                return index
            except ValueError as ve:
                raise ValueError(
                    f"Invalid SRT entry index at line {self.current_index + 1}:"
                    f" '{self.lines[self.current_index]}' is not an integer."
                ) from ve

        def _parse_timestamps(self) -> Tuple[str, str]:
            timestamps_line = self.lines[self.current_index]
            start_time, end_time = timestamps_line.split("-->")
            self.current_index += 1
            return start_time.strip(), end_time.strip()

        def _parse_text(self) -> str:
            text_lines = []
            while self.current_index < len(self.lines) \
                and self.lines[self.current_index].strip():
                text_lines.append(self.lines[self.current_index])
                self.current_index += 1
            return "\n".join(text_lines).strip()

        def _timestamp_to_ms(self, timestamp: str) -> int:
            """
            Convert SRT timestamp (HH:MM:SS,mmm) to milliseconds.

            Args:
                timestamp: SRT format timestamp

            Returns:
                Timestamp in milliseconds
            """
            pattern = r"(\d{2}):(\d{2}):(\d{2}),(\d{3})"
            match = re.match(pattern, timestamp)
            if not match:
                raise ValueError(f"Invalid timestamp format: {timestamp}")

            hours, minutes, seconds, milliseconds = map(int, match.groups())
            return hours * 3600000 + minutes * 60000 + seconds * 1000 + milliseconds

    def _generate_entry(
        self, 
        entry: TimedTextUnit, 
        index: Optional[int] = None,
        include_speaker: bool = False,
        ) -> str:
        """Generate a single SRT entry from a TimedTextUnit."""
        start_timestamp = self._ms_to_timestamp(entry.start_ms)
        end_timestamp = self._ms_to_timestamp(entry.end_ms)

        text = entry.text
        if include_speaker and entry.speaker:
            text = self.config.speaker_format.format(speaker=entry.speaker, text=text)

        srt_entry = [
            str(index or 0),
            f"{start_timestamp} --> {end_timestamp}",
            text,
            "",  # Empty line between entries
        ]
        return "\n".join(srt_entry)

    def _ms_to_timestamp(self, milliseconds: int) -> str:
        """
        Convert milliseconds to SRT timestamp format (HH:MM:SS,mmm).

        Args:
            milliseconds: Time in milliseconds

        Returns:
            Formatted timestamp string
        """
        total_seconds, ms = divmod(milliseconds, 1000)
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)

        return self.config.timestamp_format.format(hours, minutes, seconds, ms)

    def _parse_with_pysrt(self, srt_content: str) -> TimedText:
        """Internal: Parse using pysrt, extracting speaker information."""
        subs = pysrt.from_string(srt_content)
        segments = []
        for item in subs:
            speaker, text = _extract_speaker_from_text(item.text)
            segments.append(
                TimedTextUnit(
                    text=text,
                    speaker=speaker,
                    start_ms=item.start.ordinal,
                    end_ms=item.end.ordinal,
                    index=item.index,
                    granularity=Granularity.SEGMENT,
                    confidence=None,
                )
            )
        return TimedText(segments=segments)

    def _generate_with_pysrt(self, timed_text: TimedText) -> str:
        """Internal: Generate SRT using pysrt."""
        subs = pysrt.SubRipFile()
        for i, segment in enumerate(timed_text.iter_segments(), start=1):
            start = pysrt.SubRipTime(milliseconds=segment.start_ms)
            end = pysrt.SubRipTime(milliseconds=segment.end_ms)
            text = segment.text
            if self.config.include_speaker and segment.speaker:
                text = self.config.speaker_format.format(
                    speaker=segment.speaker, text=text)
            subs.append(pysrt.SubRipItem(index=i, start=start, end=end, text=text))
        return subs.to_string()
config = config or SRTConfig() instance-attribute
__init__(config=None)

Initialize with optional configuration overrides.

Parameters:

Name Type Description Default
config Optional[SRTConfig]

Configuration options for SRT processing

None
Source code in src/tnh_scholar/audio_processing/transcription/srt_processor.py
def __init__(self, config: Optional[SRTConfig] = None):
    """
    Initialize with optional configuration overrides.

    Args:
        config: Configuration options for SRT processing
    """
    self.config = config or SRTConfig()
add_speaker_labels(srt_content, *, speaker=None, speaker_labels=None)

Unified entry point for adding speaker labels. (Not implemented yet.)

Source code in src/tnh_scholar/audio_processing/transcription/srt_processor.py
def add_speaker_labels(
    self, 
    srt_content: str, 
    *, 
    speaker: Optional[str] = None, 
    speaker_labels: Optional[dict[str, list[int]]] = None
    ) -> str:
    """
    Unified entry point for adding speaker labels. 
    (Not implemented yet.)
    """
    raise NotImplementedError("add_speaker_labels is not implemented yet.")
assign_single_speaker(srt_content, speaker)

Assign the same speaker to all segments in the SRT content.

Source code in src/tnh_scholar/audio_processing/transcription/srt_processor.py
def assign_single_speaker(self, srt_content: str, speaker: str) -> str:
    """
    Assign the same speaker to all segments in the SRT content.
    """
    timed_text = self.parse(srt_content)
    timed_text.set_all_speakers(speaker)
    return self.generate(timed_text, include_speaker=True)
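The labeling applied here follows the common `Speaker: text` prefix convention. A minimal, self-contained sketch of that transformation (not the library implementation; the actual format string comes from `SRTConfig.speaker_format`):

```python
# Illustrative sketch: prefix each subtitle text line with a "Speaker: "
# label. Text lines are the non-blank lines following a timestamp line.
def label_srt_text(srt: str, speaker: str) -> str:
    out = []
    in_text = False
    for line in srt.splitlines():
        if "-->" in line:
            in_text = True
            out.append(line)
        elif in_text and line.strip():
            out.append(f"{speaker}: {line}")
        else:
            in_text = False
            out.append(line)
    return "\n".join(out)
```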
assign_speaker_by_mapping(srt_content, speaker_labels)

Assign speakers to segments based on a mapping of speaker to segment indices. (Not implemented yet.)

Source code in src/tnh_scholar/audio_processing/transcription/srt_processor.py
def assign_speaker_by_mapping(
    self, srt_content: str, speaker_labels: dict[str, list[int]]
    ) -> str:
    """
    Assign speakers to segments based on a mapping of speaker to segment indices.
    (Not implemented yet.)
    """
    raise NotImplementedError("assign_speaker_by_mapping is not implemented yet.")
combine(timed_texts)

Combine multiple TimedText objects into one, sorted by start time.

Parameters:

Name Type Description Default
timed_texts List[TimedText]

TimedText objects to combine

required

Returns:

Type Description
TimedText

A single TimedText containing all segments, ordered by start time

Source code in src/tnh_scholar/audio_processing/transcription/srt_processor.py
def combine(self, timed_texts: List[TimedText]) -> TimedText:
    """
    Combine multiple TimedText objects into one, sorted by start time.

    Args:
        timed_texts: TimedText objects to combine

    Returns:
        A single TimedText containing all segments, ordered by start time
    """
    combined_segments = []
    for timed_text in timed_texts:
        combined_segments.extend(timed_text.segments)

    # Sort by start time
    combined_segments.sort(key=lambda x: x.start_ms)

    return TimedText(segments=combined_segments)
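Conceptually, `combine` concatenates the segment lists and orders the result by start time. A minimal sketch using hypothetical `(start_ms, end_ms, text)` tuples in place of `TimedTextUnit` objects:

```python
# Two segment lists from separate transcripts, as (start_ms, end_ms, text).
talk_a = [(0, 900, "Breathing in,"), (5_000, 6_200, "I smile.")]
talk_b = [(1_000, 2_400, "I calm my body.")]

# Concatenate, then order by start time -- the core of combine().
combined = sorted(talk_a + talk_b, key=lambda seg: seg[0])
```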
generate(timed_text, include_speaker=None)

Generate SRT content from a TimedText object. Uses internal generator or pysrt depending on configuration.

Source code in src/tnh_scholar/audio_processing/transcription/srt_processor.py
def generate(
    self, 
    timed_text: TimedText, 
    include_speaker: Optional[bool] = None
) -> str:
    """
    Generate SRT content from a TimedText object.
    Uses internal generator or pysrt depending on configuration.
    """
    if include_speaker is None:
        include_speaker = self.config.include_speaker
    if self.config.use_pysrt:
        return self._generate_with_pysrt(timed_text)

    srt_parts = []
    srt_parts.extend(
        self._generate_entry(
            entry, 
            index=i if self.config.reindex_entries else entry.index,
            include_speaker=include_speaker
        )
        for i, entry in enumerate(timed_text.iter_segments(), start=1)
    )
    return "\n".join(srt_parts)
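Each generated entry follows the standard SRT layout: an index line, a `start --> end` timestamp line, the text, and a blank separator line. A self-contained sketch of that layout, assuming the default `HH:MM:SS,mmm` timestamp format:

```python
def ms_to_timestamp(ms: int) -> str:
    """Convert milliseconds to the SRT HH:MM:SS,mmm format."""
    total_seconds, msec = divmod(ms, 1000)
    hours, rem = divmod(total_seconds, 3600)
    minutes, seconds = divmod(rem, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{msec:03d}"

def srt_entry(index: int, start_ms: int, end_ms: int, text: str) -> str:
    """Render one SRT entry: index, timestamps, text, blank separator."""
    return "\n".join([
        str(index),
        f"{ms_to_timestamp(start_ms)} --> {ms_to_timestamp(end_ms)}",
        text,
        "",  # blank line between entries
    ])
```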
merge_srts(srt_list)

Merge multiple SRT files into a single SRT string.

Source code in src/tnh_scholar/audio_processing/transcription/srt_processor.py
def merge_srts(self, srt_list: List[str]) -> str:
    """Merge multiple SRT files into a single SRT string."""
    timed_text_list = [self.parse(srt) for srt in srt_list]
    combined_timed_text = self.combine(timed_text_list)
    return self.generate(combined_timed_text, self.config.include_speaker)
parse(srt_content)

Parse SRT content into a new TimedText object. Uses internal parser or pysrt depending on configuration.

Source code in src/tnh_scholar/audio_processing/transcription/srt_processor.py
def parse(self, srt_content: str) -> TimedText:
    """
    Parse SRT content into a new TimedText object.
    Uses internal parser or pysrt depending on configuration.
    """
    if self.config.use_pysrt:
        return self._parse_with_pysrt(srt_content)
    parser = self._SRTParser(srt_content)
    return parser.parse()
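At the core of parsing is the conversion from the SRT `HH:MM:SS,mmm` timestamp to milliseconds. A self-contained sketch mirroring the `_timestamp_to_ms` helper shown above:

```python
import re

def timestamp_to_ms(timestamp: str) -> int:
    """Convert an SRT timestamp (HH:MM:SS,mmm) to milliseconds."""
    match = re.match(r"(\d{2}):(\d{2}):(\d{2}),(\d{3})", timestamp)
    if not match:
        raise ValueError(f"Invalid timestamp format: {timestamp}")
    hours, minutes, seconds, ms = map(int, match.groups())
    return hours * 3_600_000 + minutes * 60_000 + seconds * 1_000 + ms
```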
shift_timestamps(timed_text, offset_ms)

Shift all timestamps by the given offset.

Parameters:

Name Type Description Default
timed_text TimedText

TimedText object to shift

required
offset_ms int

Offset in milliseconds to apply

required

Returns:

Type Description
TimedText

New TimedText with all timestamps shifted by offset_ms

Source code in src/tnh_scholar/audio_processing/transcription/srt_processor.py
def shift_timestamps(self, timed_text: TimedText, offset_ms: int) -> TimedText:
    """
    Shift all timestamps by the given offset.

    Args:
        timed_text: TimedText object to shift
        offset_ms: Offset in milliseconds to apply

    Returns:
        A new TimedText with all timestamps shifted by offset_ms
    """
    new_segments = [
        segment.shift_time(offset_ms)
        for segment in timed_text.iter_segments()
    ]
    return TimedText(segments=new_segments)
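The shift itself is just a constant offset applied to every segment's start and end times, as in this minimal sketch with bare `(start_ms, end_ms)` pairs:

```python
# Segment boundaries in milliseconds, and a 5-second offset to apply.
segments = [(0, 900), (1_000, 2_400)]
offset_ms = 5_000

# Shift every boundary by the same offset.
shifted = [(start + offset_ms, end + offset_ms) for start, end in segments]
```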
SubtitleFormat

Bases: str, Enum

Supported subtitle formats.

Source code in src/tnh_scholar/audio_processing/transcription/srt_processor.py
class SubtitleFormat(str, Enum):
    """Supported subtitle formats."""
    SRT = "srt"
    VTT = "vtt"
    TEXT = "text"
SRT = 'srt' class-attribute instance-attribute
TEXT = 'text' class-attribute instance-attribute
VTT = 'vtt' class-attribute instance-attribute
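Because the enum mixes in `str`, members compare equal to their plain string values, which simplifies CLI and config handling. For example:

```python
from enum import Enum

class SubtitleFormat(str, Enum):
    """Supported subtitle formats."""
    SRT = "srt"
    VTT = "vtt"
    TEXT = "text"

# Members can be constructed from raw strings and compared against them.
fmt = SubtitleFormat("srt")
```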
text_segment_builder

SegmentBuilder for creating phrase-level segments from word-level TimedText.

This module builds higher-level segments from a TimedText object containing word-level units, based on configurable criteria like duration, character count, punctuation, pauses, and speaker changes.

COMMON_ABBREVIATIONS = frozenset({'adj.', 'adm.', 'adv.', 'al.', 'anon.', 'apr.', 'arc.', 'aug.', 'ave.', 'brig.', 'bros.', 'capt.', 'cmdr.', 'col.', 'comdr.', 'con.', 'corp.', 'cpl.', 'dr.', 'drs.', 'ed.', 'enc.', 'etc.', 'ex.', 'feb.', 'gen.', 'gov.', 'hon.', 'hosp.', 'hr.', 'inc.', 'jan.', 'jr.', 'maj.', 'mar.', 'messrs.', 'mlle.', 'mm.', 'mme.', 'mr.', 'mrs.', 'ms.', 'msgr.', 'nov.', 'oct.', 'op.', 'ord.', 'ph.d.', 'prof.', 'pvt.', 'rep.', 'reps.', 'res.', 'rev.', 'rt.', 'sen.', 'sens.', 'sep.', 'sfc.', 'sgt.', 'sr.', 'st.', 'supt.', 'surg.', 'u.s.', 'v.p.', 'vs.'}) module-attribute
TextSegmentBuilder
Source code in src/tnh_scholar/audio_processing/transcription/text_segment_builder.py
class TextSegmentBuilder:
    def __init__(
        self,
        *,
        max_duration_ms: Optional[int] = None, # milliseconds
        target_characters: Optional[int] = None,
        avoid_orphans: bool = True,
        max_gap_duration_ms: Optional[int] = None, # milliseconds
        ignore_speaker: bool = True,
    ):
        self.max_duration = max_duration_ms
        self.target_characters = target_characters
        self.avoid_orphans = avoid_orphans
        self.max_gap_duration = max_gap_duration_ms
        self.ignore_speaker = ignore_speaker

        self.segments: List[TimedTextUnit] = []
        self.current_words: List[TimedTextUnit] = []
        self.current_characters = 0

    def create_segments(self, timed_text: TimedText) -> TimedText:
        # Validate
        if not timed_text.words:
            raise ValueError(
                "TimedText object must have word-level units to build segments."
                )

        for unit in timed_text.words:
            if unit.granularity != Granularity.WORD:
                raise ValueError(f"Expected WORD units, got {unit.granularity}")

        # Process
        for word in timed_text.words:
            if self._should_start_new_segment(word):
                self._flush_current_words()
            self._add_word(word)

        self._flush_current_words()  # Final flush
        return TimedText(segments=self.segments, granularity=Granularity.SEGMENT)

    def _add_word(self, word: TimedTextUnit):
        if self.current_words:
            self.current_characters += 1  # space before the new word
        self.current_characters += len(word.text)
        self.current_words.append(word)


    def _should_start_new_segment(self, word: TimedTextUnit) -> bool:
        if not self.current_words:
            return False

        # Speaker change
        last_word = self.current_words[-1]
        if not self.ignore_speaker and (word.speaker != last_word.speaker):
            return True

        # Significant pause
        if self.max_gap_duration is not None:
            pause = word.start_ms - last_word.end_ms
            if pause > self.max_gap_duration:
                return True

        # End punctuation
        if last_word.text and self._is_punctuation_word(last_word.text):
            return True

        # Max duration
        if self.max_duration is not None:
            duration = word.end_ms - self.current_words[0].start_ms
            if duration > self.max_duration:
                return True

        # Target character count
        if self.target_characters is not None:
            total_chars = self.current_characters + len(word.text) + 1
            if total_chars > self.target_characters:
                return True

        return False

    def _flush_current_words(self):
        if not self.current_words:
            return

        segment_text = " ".join(word.text for word in self.current_words)
        segment = TimedTextUnit(
            text=segment_text,
            start_ms=self.current_words[0].start_ms,
            end_ms=self.current_words[-1].end_ms,
            granularity=Granularity.SEGMENT,
            speaker=None if self.ignore_speaker else self._find_speaker(),
            confidence=None,
            index=None,
        )
        self.segments.append(segment)
        self.current_words = []
        self.current_characters = 0

    def _find_speaker(self) -> Optional[str]:
        """
        Only called when ignore_speaker is False; in that case we always
        split on speaker changes, so exactly one speaker is expected.
        """
        speakers = {word.speaker for word in self.current_words}
        assert len(speakers) == 1, "Inconsistent speakers in segment"
        return speakers.pop()

    def _is_punctuation_word(self, word_text: str) -> bool:
        """
        Check if a word ending in punctuation should trigger a new segment,
        excluding common abbreviations.
        """
        if not word_text:
            return False
        return word_text[-1] in ".!?" and word_text.lower() not in COMMON_ABBREVIATIONS


    def build_segments(
        self,
        *,
        target_duration: Optional[int] = None,
        target_characters: Optional[int] = None,
        avoid_orphans: Optional[bool] = True,
        max_gap_duration: Optional[int] = None,
        ignore_speaker: bool = False,
    ) -> None:
        """
        Build or rebuild `segments` from the contents of `words`.

        Args:
            target_duration: Maximum desired segment duration in milliseconds.
            target_characters: Maximum desired character length of a segment.
            avoid_orphans: Whether to avoid leaving a very short final segment.
            max_gap_duration: Maximum pause, in milliseconds, allowed within a segment.
            ignore_speaker: Whether to ignore speaker changes when segmenting.

        Note:
            This is a stub.  Concrete algorithms will be implemented later.

        Raises:
            NotImplementedError: Always, until implemented.
        """
        raise NotImplementedError("build_segments is not yet implemented.")
avoid_orphans = avoid_orphans instance-attribute
current_characters = 0 instance-attribute
current_words = [] instance-attribute
ignore_speaker = ignore_speaker instance-attribute
max_duration = max_duration_ms instance-attribute
max_gap_duration = max_gap_duration_ms instance-attribute
segments = [] instance-attribute
target_characters = target_characters instance-attribute
__init__(*, max_duration_ms=None, target_characters=None, avoid_orphans=True, max_gap_duration_ms=None, ignore_speaker=True)
Source code in src/tnh_scholar/audio_processing/transcription/text_segment_builder.py
def __init__(
    self,
    *,
    max_duration_ms: Optional[int] = None, # milliseconds
    target_characters: Optional[int] = None,
    avoid_orphans: bool = True,
    max_gap_duration_ms: Optional[int] = None, # milliseconds
    ignore_speaker: bool = True,
):
    self.max_duration = max_duration_ms
    self.target_characters = target_characters
    self.avoid_orphans = avoid_orphans
    self.max_gap_duration = max_gap_duration_ms
    self.ignore_speaker = ignore_speaker

    self.segments: List[TimedTextUnit] = []
    self.current_words: List[TimedTextUnit] = []
    self.current_characters = 0
build_segments(*, target_duration=None, target_characters=None, avoid_orphans=True, max_gap_duration=None, ignore_speaker=False)

Build or rebuild segments from the contents of words.

Parameters:

Name Type Description Default
target_duration Optional[int]

Maximum desired segment duration in milliseconds.

None
target_characters Optional[int]

Maximum desired character length of a segment.

None
avoid_orphans Optional[bool]

Whether to avoid leaving a very short final segment.

True
max_gap_duration Optional[int]

Maximum pause, in milliseconds, allowed within a segment.

None
ignore_speaker bool

Whether to ignore speaker changes when segmenting.

False
Note

This is a stub. Concrete algorithms will be implemented later.

Raises:

Type Description
NotImplementedError

Always, until implemented.

Source code in src/tnh_scholar/audio_processing/transcription/text_segment_builder.py
def build_segments(
    self,
    *,
    target_duration: Optional[int] = None,
    target_characters: Optional[int] = None,
    avoid_orphans: Optional[bool] = True,
    max_gap_duration: Optional[int] = None,
    ignore_speaker: bool = False,
) -> None:
    """
    Build or rebuild `segments` from the contents of `words`.

    Args:
        target_duration: Maximum desired segment duration in milliseconds.
        target_characters: Maximum desired character length of a segment.
        avoid_orphans: Whether to avoid leaving a very short final segment.
        max_gap_duration: Maximum pause, in milliseconds, allowed within a segment.
        ignore_speaker: Whether to ignore speaker changes when segmenting.

    Note:
        This is a stub.  Concrete algorithms will be implemented later.

    Raises:
        NotImplementedError: Always, until implemented.
    """
    raise NotImplementedError("build_segments is not yet implemented.")
create_segments(timed_text)
Source code in src/tnh_scholar/audio_processing/transcription/text_segment_builder.py
def create_segments(self, timed_text: TimedText) -> TimedText:
    # Validate
    if not timed_text.words:
        raise ValueError(
            "TimedText object must have word-level units to build segments."
            )

    for unit in timed_text.words:
        if unit.granularity != Granularity.WORD:
            raise ValueError(f"Expected WORD units, got {unit.granularity}")

    # Process
    for word in timed_text.words:
        if self._should_start_new_segment(word):
            self._flush_current_words()
        self._add_word(word)

    self._flush_current_words()  # Final flush
    return TimedText(segments=self.segments, granularity=Granularity.SEGMENT)
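A simplified, self-contained sketch of the grouping loop at the heart of `create_segments`, using hypothetical `(text, start_ms, end_ms)` tuples and splitting on sentence-ending punctuation or long pauses (abbreviation and speaker handling omitted):

```python
def group_words(words, max_gap_ms=700):
    """Group (text, start_ms, end_ms) word tuples into segment strings."""
    segments, current = [], []
    for word in words:
        if current:
            last = current[-1]
            pause = word[1] - last[2]  # gap between last end and next start
            # Split after end punctuation or a pause longer than the limit.
            if last[0][-1] in ".!?" or pause > max_gap_ms:
                segments.append(current)
                current = []
        current.append(word)
    if current:
        segments.append(current)  # final flush
    return [" ".join(w[0] for w in seg) for seg in segments]
```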
transcription_service
TranscriptionResult

Bases: BaseModel

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
class TranscriptionResult(BaseModel):
    text: str
    language: str
    word_timing: Optional[TimedText] = None
    utterance_timing: Optional[TimedText] = None
    confidence: Optional[float] = None
    audio_duration_ms: Optional[int] = None
    transcript_id: Optional[str] = None
    status: Optional[str] = None
    raw_result: Optional[Dict[str, Any]] = None
audio_duration_ms = None class-attribute instance-attribute
confidence = None class-attribute instance-attribute
language instance-attribute
raw_result = None class-attribute instance-attribute
status = None class-attribute instance-attribute
text instance-attribute
transcript_id = None class-attribute instance-attribute
utterance_timing = None class-attribute instance-attribute
word_timing = None class-attribute instance-attribute
TranscriptionService

Bases: ABC

Abstract base class defining the interface for transcription services.

This interface provides a standard way to interact with different transcription service providers (e.g., OpenAI Whisper, AssemblyAI).

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
class TranscriptionService(ABC):
    """
    Abstract base class defining the interface for transcription services.

    This interface provides a standard way to interact with different
    transcription service providers (e.g., OpenAI Whisper, AssemblyAI).
    """

    @abstractmethod
    def transcribe(
        self,
        audio_file: Union[Path, BytesIO],
        options: Optional[Dict[str, Any]] = None
    ) -> TranscriptionResult:
        """
        Transcribe audio file to text.

        Args:
            audio_file: Path to audio file or file-like object
            options: Provider-specific options for transcription

        Returns:
            TranscriptionResult

        """
        pass

    @abstractmethod
    def get_result(self, job_id: str) -> TranscriptionResult:
        """
        Get results for an existing transcription job.

        Args:
            job_id: ID of the transcription job

        Returns:
            TranscriptionResult containing the transcription in the same
            standardized format as transcribe()
        """
        pass

    @abstractmethod
    def transcribe_to_format(
        self,
        audio_file: Union[Path, BytesIO],
        format_type: str = "srt",
        transcription_options: Optional[Dict[str, Any]] = None,
        format_options: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Transcribe audio and return result in specified format.

        Args:
            audio_file: Path, file-like object, or URL of audio file
            format_type: Format type (e.g., "srt", "vtt", "text")
            transcription_options: Options for transcription
            format_options: Format-specific options

        Returns:
            String representation in the requested format
        """
        pass
get_result(job_id) abstractmethod

Get results for an existing transcription job.

Parameters:

Name Type Description Default
job_id str

ID of the transcription job

required

Returns:

Type Description
TranscriptionResult

TranscriptionResult containing the transcription in the same standardized format as transcribe()

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
@abstractmethod
def get_result(self, job_id: str) -> TranscriptionResult:
    """
    Get results for an existing transcription job.

    Args:
        job_id: ID of the transcription job

    Returns:
        TranscriptionResult containing the transcription in the same
        standardized format as transcribe()
    """
    pass
transcribe(audio_file, options=None) abstractmethod

Transcribe audio file to text.

Parameters:

Name Type Description Default
audio_file Union[Path, BytesIO]

Path to audio file or file-like object

required
options Optional[Dict[str, Any]]

Provider-specific options for transcription

None

Returns:

Type Description
TranscriptionResult

TranscriptionResult

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
@abstractmethod
def transcribe(
    self,
    audio_file: Union[Path, BytesIO],
    options: Optional[Dict[str, Any]] = None
) -> TranscriptionResult:
    """
    Transcribe audio file to text.

    Args:
        audio_file: Path to audio file or file-like object
        options: Provider-specific options for transcription

    Returns:
        TranscriptionResult

    """
    pass
transcribe_to_format(audio_file, format_type='srt', transcription_options=None, format_options=None) abstractmethod

Transcribe audio and return result in specified format.

Parameters:

Name Type Description Default
audio_file Union[Path, BytesIO]

Path, file-like object, or URL of audio file

required
format_type str

Format type (e.g., "srt", "vtt", "text")

'srt'
transcription_options Optional[Dict[str, Any]]

Options for transcription

None
format_options Optional[Dict[str, Any]]

Format-specific options

None

Returns:

Type Description
str

String representation in the requested format

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
@abstractmethod
def transcribe_to_format(
    self,
    audio_file: Union[Path, BytesIO],
    format_type: str = "srt",
    transcription_options: Optional[Dict[str, Any]] = None,
    format_options: Optional[Dict[str, Any]] = None
) -> str:
    """
    Transcribe audio and return result in specified format.

    Args:
        audio_file: Path, file-like object, or URL of audio file
        format_type: Format type (e.g., "srt", "vtt", "text")
        transcription_options: Options for transcription
        format_options: Format-specific options

    Returns:
        String representation in the requested format
    """
    pass
TranscriptionServiceFactory

Factory for creating transcription service instances.

This factory provides a standard way to create transcription service instances based on the provider name and configuration.

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
class TranscriptionServiceFactory:
    """
    Factory for creating transcription service instances.

    This factory provides a standard way to create transcription service
    instances based on the provider name and configuration.
    """

    # Mapping provider names to implementation classes
    # Classes will be imported lazily when needed
    _PROVIDER_MAP: Dict[str, Callable[..., TranscriptionService]] = {}

    @classmethod
    def register_provider(
        cls, 
        name: str, 
        provider_class: Callable[..., TranscriptionService]
    ) -> None:
        """
        Register a provider implementation with the factory.

        Args:
            name: Provider name (lowercase)
            provider_class: Provider implementation class or factory function

        Example:
            >>> from my_module import MyTranscriptionService
            >>> TranscriptionServiceFactory.register_provider("my_provider", MyTranscriptionService)
        """  
        cls._PROVIDER_MAP[name.lower()] = provider_class

    @classmethod
    def create_service(
        cls,
        provider: str = "assemblyai",
        api_key: Optional[str] = None,
        **kwargs
    ) -> TranscriptionService:
        """
        Create a transcription service instance.

        Args:
            provider: Service provider name (e.g., "whisper", "assemblyai")
            api_key: API key for the service
            **kwargs: Additional provider-specific configuration

        Returns:
            TranscriptionService instance

        Raises:
            ValueError: If the provider is not supported
            ImportError: If the provider module cannot be imported
        """
        provider = provider.lower()

        # Initialize provider map if empty
        if not cls._PROVIDER_MAP:
            # Import lazily to avoid circular imports
            from .assemblyai_service import AAITranscriptionService
            from .whisper_service import WhisperTranscriptionService

            cls._PROVIDER_MAP = {
                "whisper": WhisperTranscriptionService,
                "assemblyai": AAITranscriptionService,
            }

        # Get the provider implementation
        provider_class = cls._PROVIDER_MAP.get(provider)

        if provider_class is None:
            raise ValueError(f"Unsupported transcription provider: {provider}")

        # Create and return the service instance
        return provider_class(api_key=api_key, **kwargs)
create_service(provider='assemblyai', api_key=None, **kwargs) classmethod

Create a transcription service instance.

Parameters:

- `provider` (str): Service provider name (e.g., "whisper", "assemblyai"). Default: `'assemblyai'`
- `api_key` (Optional[str]): API key for the service. Default: `None`
- `**kwargs`: Additional provider-specific configuration. Default: `{}`

Returns:

- `TranscriptionService`: TranscriptionService instance

Raises:

- `ValueError`: If the provider is not supported
- `ImportError`: If the provider module cannot be imported

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
@classmethod
def create_service(
    cls,
    provider: str = "assemblyai",
    api_key: Optional[str] = None,
    **kwargs
) -> TranscriptionService:
    """
    Create a transcription service instance.

    Args:
        provider: Service provider name (e.g., "whisper", "assemblyai")
        api_key: API key for the service
        **kwargs: Additional provider-specific configuration

    Returns:
        TranscriptionService instance

    Raises:
        ValueError: If the provider is not supported
        ImportError: If the provider module cannot be imported
    """
    provider = provider.lower()

    # Initialize provider map if empty
    if not cls._PROVIDER_MAP:
        # Import lazily to avoid circular imports
        from .assemblyai_service import AAITranscriptionService
        from .whisper_service import WhisperTranscriptionService

        cls._PROVIDER_MAP = {
            "whisper": WhisperTranscriptionService,
            "assemblyai": AAITranscriptionService,
        }

    # Get the provider implementation
    provider_class = cls._PROVIDER_MAP.get(provider)

    if provider_class is None:
        raise ValueError(f"Unsupported transcription provider: {provider}")

    # Create and return the service instance
    return provider_class(api_key=api_key, **kwargs)
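The register/create flow above can be sketched in isolation. The following is a minimal, self-contained illustration of the same registry pattern; `StubService` and `Factory` are illustrative stand-ins, not classes from tnh_scholar:

```python
from typing import Callable, Dict, Optional


class StubService:
    """Illustrative stand-in for a TranscriptionService implementation."""

    def __init__(self, api_key: Optional[str] = None, **kwargs):
        self.api_key = api_key
        self.options = kwargs


class Factory:
    """Minimal provider registry mirroring the factory pattern above."""

    _PROVIDER_MAP: Dict[str, Callable[..., StubService]] = {}

    @classmethod
    def register_provider(
        cls, name: str, provider_class: Callable[..., StubService]
    ) -> None:
        # Names are normalized to lowercase, as in the factory above.
        cls._PROVIDER_MAP[name.lower()] = provider_class

    @classmethod
    def create_service(
        cls, provider: str, api_key: Optional[str] = None, **kwargs
    ) -> StubService:
        provider_class = cls._PROVIDER_MAP.get(provider.lower())
        if provider_class is None:
            raise ValueError(f"Unsupported transcription provider: {provider}")
        return provider_class(api_key=api_key, **kwargs)


Factory.register_provider("Stub", StubService)
svc = Factory.create_service("STUB", api_key="sk-test", language="vi")
```

Because lookup is case-insensitive, `"STUB"` resolves to the provider registered as `"Stub"`; extra keyword arguments are forwarded unchanged to the provider constructor.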
register_provider(name, provider_class) classmethod

Register a provider implementation with the factory.

Parameters:

- `name` (str): Provider name (lowercase). Required.
- `provider_class` (Callable[..., TranscriptionService]): Provider implementation class or factory function. Required.

Example:

>>> from my_module import MyTranscriptionService
>>> TranscriptionServiceFactory.register_provider("my_provider", MyTranscriptionService)

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
@classmethod
def register_provider(
    cls, 
    name: str, 
    provider_class: Callable[..., TranscriptionService]
) -> None:
    """
    Register a provider implementation with the factory.

    Args:
        name: Provider name (lowercase)
        provider_class: Provider implementation class or factory function

    Example:
        >>> from my_module import MyTranscriptionService
        >>> TranscriptionServiceFactory.register_provider("my_provider", MyTranscriptionService)
    """  
    cls._PROVIDER_MAP[name.lower()] = provider_class
Utterance

Bases: BaseModel

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
class Utterance(BaseModel):
    speaker: Optional[str]
    start_ms: int
    end_ms: int
    text: str
    confidence: float
confidence instance-attribute
end_ms instance-attribute
speaker instance-attribute
start_ms instance-attribute
text instance-attribute
WordTiming

Bases: BaseModel

Source code in src/tnh_scholar/audio_processing/transcription/transcription_service.py
class WordTiming(BaseModel):
    word: str
    start_ms: int
    end_ms: int
    confidence: float
confidence instance-attribute
end_ms instance-attribute
start_ms instance-attribute
word instance-attribute
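WordTiming stores timestamps as integer milliseconds, while the Whisper API reports float seconds. A brief sketch of the conversion (the standalone `seconds_to_ms` helper here is illustrative; the whisper service performs the same conversion internally via `_seconds_to_ms`):

```python
from typing import Optional


def seconds_to_ms(seconds: Optional[float]) -> Optional[int]:
    """Convert float seconds to integer milliseconds; None passes through."""
    return None if seconds is None else int(seconds * 1000)


start_ms = seconds_to_ms(1.5)  # 1500
end_ms = seconds_to_ms(None)   # None
```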
vtt_processor
VTTConfig

Configuration options for WebVTT processing.

Source code in src/tnh_scholar/audio_processing/transcription/vtt_processor.py
class VTTConfig:
    """Configuration options for WebVTT processing."""

    def __init__(
        self,
        include_speaker=False,
        speaker_format="<v {speaker}>{text}",
        reindex_entries=False,
        timestamp_format="{:02d}:{:02d}:{:02d}.{:03d}",
        max_chars_per_line=42
    ):
        """
        Initialize with default settings.

        Args:
            include_speaker: Whether to include speaker labels in output
            speaker_format: Format string for speaker attribution
            reindex_entries: Whether to reindex entries sequentially
            timestamp_format: Format string for timestamp formatting
            max_chars_per_line: Maximum characters per line before splitting
        """
        self.include_speaker = include_speaker
        self.speaker_format = speaker_format
        self.reindex_entries = reindex_entries
        self.timestamp_format = timestamp_format
        self.max_chars_per_line = max_chars_per_line
include_speaker = include_speaker instance-attribute
max_chars_per_line = max_chars_per_line instance-attribute
reindex_entries = reindex_entries instance-attribute
speaker_format = speaker_format instance-attribute
timestamp_format = timestamp_format instance-attribute
__init__(include_speaker=False, speaker_format='<v {speaker}>{text}', reindex_entries=False, timestamp_format='{:02d}:{:02d}:{:02d}.{:03d}', max_chars_per_line=42)

Initialize with default settings.

Parameters:

- `include_speaker`: Whether to include speaker labels in output. Default: `False`
- `speaker_format`: Format string for speaker attribution. Default: `'<v {speaker}>{text}'`
- `reindex_entries`: Whether to reindex entries sequentially. Default: `False`
- `timestamp_format`: Format string for timestamp formatting. Default: `'{:02d}:{:02d}:{:02d}.{:03d}'`
- `max_chars_per_line`: Maximum characters per line before splitting. Default: `42`
Source code in src/tnh_scholar/audio_processing/transcription/vtt_processor.py
def __init__(
    self,
    include_speaker=False,
    speaker_format="<v {speaker}>{text}",
    reindex_entries=False,
    timestamp_format="{:02d}:{:02d}:{:02d}.{:03d}",
    max_chars_per_line=42
):
    """
    Initialize with default settings.

    Args:
        include_speaker: Whether to include speaker labels in output
        speaker_format: Format string for speaker attribution
        reindex_entries: Whether to reindex entries sequentially
        timestamp_format: Format string for timestamp formatting
        max_chars_per_line: Maximum characters per line before splitting
    """
    self.include_speaker = include_speaker
    self.speaker_format = speaker_format
    self.reindex_entries = reindex_entries
    self.timestamp_format = timestamp_format
    self.max_chars_per_line = max_chars_per_line
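The default `timestamp_format` of `'{:02d}:{:02d}:{:02d}.{:03d}'` yields WebVTT-style `HH:MM:SS.mmm` timestamps. A minimal sketch of applying such a format string to a millisecond offset (the `format_timestamp` helper is hypothetical, not part of vtt_processor):

```python
def format_timestamp(ms: int, fmt: str = "{:02d}:{:02d}:{:02d}.{:03d}") -> str:
    """Render an integer millisecond offset with a VTTConfig-style format string."""
    hours, rem = divmod(ms, 3_600_000)   # ms per hour
    minutes, rem = divmod(rem, 60_000)   # ms per minute
    seconds, millis = divmod(rem, 1_000)
    return fmt.format(hours, minutes, seconds, millis)


format_timestamp(3_725_042)  # "01:02:05.042"
```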
VTTProcessor

Handles parsing and generating WebVTT format.

Source code in src/tnh_scholar/audio_processing/transcription/vtt_processor.py
class VTTProcessor:
    """Handles parsing and generating WebVTT format."""

    def __init__(self, config: Optional[VTTConfig] = None):
        """
        Initialize with optional configuration.

        Args:
            config: Configuration options for VTT processing
        """
        self.config = config or VTTConfig()

    def parse(self, vtt_content: str) -> List[TimedTextUnit]:
        """
        Parse VTT content into a list of TimedTextUnit objects.

        Args:
            vtt_content: String containing VTT formatted content

        Returns:
            List of TimedTextUnit objects
        """
        # Implementation will go here
        raise NotImplementedError("Not implemented.")

    def generate(self, timed_texts: List[TimedTextUnit]) -> str:
        """
        Generate VTT content from a list of TimedTextUnit objects.

        Args:
            timed_texts: List of TimedTextUnit objects

        Returns:
            String containing VTT formatted content
        """
        # Implementation will go here
        raise NotImplementedError("Not implemented.")
config = config or VTTConfig() instance-attribute
__init__(config=None)

Initialize with optional configuration.

Parameters:

- `config` (Optional[VTTConfig]): Configuration options for VTT processing. Default: `None`
Source code in src/tnh_scholar/audio_processing/transcription/vtt_processor.py
def __init__(self, config: Optional[VTTConfig] = None):
    """
    Initialize with optional configuration.

    Args:
        config: Configuration options for VTT processing
    """
    self.config = config or VTTConfig()
generate(timed_texts)

Generate VTT content from a list of TimedTextUnit objects.

Parameters:

- `timed_texts` (List[TimedTextUnit]): List of TimedTextUnit objects. Required.

Returns:

- `str`: String containing VTT formatted content

Source code in src/tnh_scholar/audio_processing/transcription/vtt_processor.py
def generate(self, timed_texts: List[TimedTextUnit]) -> str:
    """
    Generate VTT content from a list of TimedTextUnit objects.

    Args:
        timed_texts: List of TimedTextUnit objects

    Returns:
        String containing VTT formatted content
    """
    # Implementation will go here
    raise NotImplementedError("Not implemented.")
parse(vtt_content)

Parse VTT content into a list of TimedTextUnit objects.

Parameters:

- `vtt_content` (str): String containing VTT formatted content. Required.

Returns:

- `List[TimedTextUnit]`: List of TimedTextUnit objects

Source code in src/tnh_scholar/audio_processing/transcription/vtt_processor.py
def parse(self, vtt_content: str) -> List[TimedTextUnit]:
    """
    Parse VTT content into a list of TimedTextUnit objects.

    Args:
        vtt_content: String containing VTT formatted content

    Returns:
        List of TimedTextUnit objects
    """
    # Implementation will go here
    raise NotImplementedError("Not implemented.")
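Since parse() is not yet implemented, the following is only an illustrative sketch of the cue extraction a WebVTT parser performs, using plain dicts rather than TimedTextUnit; it is not the project's implementation:

```python
import re

# Matches a WebVTT cue timing line: "HH:MM:SS.mmm --> HH:MM:SS.mmm"
CUE_RE = re.compile(
    r"(\d{2}):(\d{2}):(\d{2})\.(\d{3})\s-->\s(\d{2}):(\d{2}):(\d{2})\.(\d{3})"
)


def to_ms(h: str, m: str, s: str, ms: str) -> int:
    """Convert captured timestamp fields to a millisecond offset."""
    return ((int(h) * 60 + int(m)) * 60 + int(s)) * 1000 + int(ms)


def parse_cues(vtt: str):
    """Extract start/end/text cues from WebVTT content as plain dicts."""
    cues = []
    lines = vtt.splitlines()
    for i, line in enumerate(lines):
        match = CUE_RE.match(line.strip())
        if match:
            groups = match.groups()
            # The cue payload follows the timing line.
            text = lines[i + 1].strip() if i + 1 < len(lines) else ""
            cues.append({
                "start_ms": to_ms(*groups[:4]),
                "end_ms": to_ms(*groups[4:]),
                "text": text,
            })
    return cues


sample = "WEBVTT\n\n00:00:01.000 --> 00:00:03.500\nBreathing in, I calm my body."
cues = parse_cues(sample)
```

A real implementation would also handle cue identifiers, multi-line payloads, and cue settings, which this sketch ignores.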
whisper_service
TODO: MAJOR REFACTOR PLANNED

This module currently mixes persistent service configuration (WhisperConfig) with per-call runtime options, leading to complex validation and logic. The plan is to:

  • Refactor so each WhisperTranscriptionService instance is configured once at construction, with all relevant settings (including file-like/path-like mode, file extension, etc).
  • Use Pydantic BaseSettings for configuration to normalize configuration and validation according to TNH Scholar style.
  • Remove ad-hoc runtime options from the transcribe() entrypoint; all config should be set at init.
  • If a different configuration is needed, instantiate a new service object.
  • This will simplify validation, error handling, and code logic, and make the contract clear and robust.
  • NOTE: This will change the TranscriptionService contract and will require similar changes in other transcription system implementations.
  • Update all dependent code and tests accordingly.

logger = get_child_logger(__name__) module-attribute
WhisperBase

Bases: TypedDict

Source code in src/tnh_scholar/audio_processing/transcription/whisper_service.py
class WhisperBase(TypedDict):
    text: str
    language: str
    duration: float
duration instance-attribute
language instance-attribute
text instance-attribute
WhisperConfig dataclass

Configuration for the Whisper transcription service.

Source code in src/tnh_scholar/audio_processing/transcription/whisper_service.py
@dataclass
class WhisperConfig:
    """Configuration for the Whisper transcription service."""
    model: str = "whisper-1"
    response_format: str = "verbose_json"
    timestamp_granularities: Optional[List[str]] = field(
        default_factory=lambda: ["word"]
        )
    chunking_strategy: str = "auto" # currently not usable
    language: Optional[str] = None # language code
    temperature: Optional[float] = None
    prompt: Optional[str] = None

    # Supported response formats
    SUPPORTED_FORMATS = ["json", "text", "srt", "vtt", "verbose_json"]

    # Parameters allowed for each format
    FORMAT_PARAMS = {
        "verbose_json": ["timestamp_granularities"],
        "json": [],
        "text": [],
        "srt": [],
        "vtt": []
    }

    # Basic parameters: always allowed
    BASE_PARAMS = [
            "model", "language", "temperature", "prompt", "response_format",
        ]


    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to dictionary for API call."""
        # Filter out None values to avoid sending unnecessary parameters
        return {k: v for k, v in self.__dict__.items() 
                if v is not None and not k.startswith("_") and k != "SUPPORTED_FORMATS"}

    def validate(self) -> None:
        """Validate configuration values."""
        if self.response_format not in self.SUPPORTED_FORMATS:
            logger.warning(
                f"Unsupported response format: {self.response_format}, "
                f"defaulting to 'verbose_json'"
            )
            self.response_format = "verbose_json"
BASE_PARAMS = ['model', 'language', 'temperature', 'prompt', 'response_format'] class-attribute instance-attribute
FORMAT_PARAMS = {'verbose_json': ['timestamp_granularities'], 'json': [], 'text': [], 'srt': [], 'vtt': []} class-attribute instance-attribute
SUPPORTED_FORMATS = ['json', 'text', 'srt', 'vtt', 'verbose_json'] class-attribute instance-attribute
chunking_strategy = 'auto' class-attribute instance-attribute
language = None class-attribute instance-attribute
model = 'whisper-1' class-attribute instance-attribute
prompt = None class-attribute instance-attribute
response_format = 'verbose_json' class-attribute instance-attribute
temperature = None class-attribute instance-attribute
timestamp_granularities = field(default_factory=(lambda: ['word'])) class-attribute instance-attribute
__init__(model='whisper-1', response_format='verbose_json', timestamp_granularities=['word'], chunking_strategy='auto', language=None, temperature=None, prompt=None)
to_dict()

Convert configuration to dictionary for API call.

Source code in src/tnh_scholar/audio_processing/transcription/whisper_service.py
def to_dict(self) -> Dict[str, Any]:
    """Convert configuration to dictionary for API call."""
    # Filter out None values to avoid sending unnecessary parameters
    return {k: v for k, v in self.__dict__.items() 
            if v is not None and not k.startswith("_") and k != "SUPPORTED_FORMATS"}
validate()

Validate configuration values.

Source code in src/tnh_scholar/audio_processing/transcription/whisper_service.py
def validate(self) -> None:
    """Validate configuration values."""
    if self.response_format not in self.SUPPORTED_FORMATS:
        logger.warning(
            f"Unsupported response format: {self.response_format}, "
            f"defaulting to 'verbose_json'"
        )
        self.response_format = "verbose_json"
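to_dict() drops None-valued fields so optional parameters such as language and temperature are only sent to the API when actually set. A self-contained sketch of that filtering behavior (the `Config` class here is a simplified stand-in for WhisperConfig):

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class Config:
    """Simplified stand-in for WhisperConfig."""
    model: str = "whisper-1"
    response_format: str = "verbose_json"
    language: Optional[str] = None
    temperature: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        # Omit unset (None) fields so they are not sent to the API.
        return {k: v for k, v in self.__dict__.items() if v is not None}


params = Config(language="vi").to_dict()
# {'model': 'whisper-1', 'response_format': 'verbose_json', 'language': 'vi'}
```

Note that `temperature` never appears in the result unless explicitly set, which keeps the API call free of spurious defaults.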
WhisperResponse

Bases: WhisperBase

Source code in src/tnh_scholar/audio_processing/transcription/whisper_service.py
class WhisperResponse(WhisperBase, total=False):
    words: Optional[List[WordEntry]]
    segments: Optional[List[WhisperSegment]]
segments instance-attribute
words instance-attribute
WhisperSegment

Bases: TypedDict

Source code in src/tnh_scholar/audio_processing/transcription/whisper_service.py
class WhisperSegment(TypedDict, total=False):
    id: int
    start: float
    end: float
    text: str
    temperature: float
    avg_logprob: float
    compression_ratio: float
    no_speech_prob: float
avg_logprob instance-attribute
compression_ratio instance-attribute
end instance-attribute
id instance-attribute
no_speech_prob instance-attribute
start instance-attribute
temperature instance-attribute
text instance-attribute
WhisperTranscriptionService

Bases: TranscriptionService

OpenAI Whisper implementation of the TranscriptionService interface.

Provides transcription services using the OpenAI Whisper API.

Source code in src/tnh_scholar/audio_processing/transcription/whisper_service.py
class WhisperTranscriptionService(TranscriptionService):
    """
    OpenAI Whisper implementation of the TranscriptionService interface.

    Provides transcription services using the OpenAI Whisper API.
    """

    def __init__(self, api_key: Optional[str] = None, **config_options):
        """
        Initialize the Whisper transcription service.

        Args:
            api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
            **config_options: Additional configuration options
        """
        # Create configuration base
        self.config = WhisperConfig()

        # Set any configuration options provided
        for key, value in config_options.items():
            if hasattr(self.config, key):
                setattr(self.config, key, value)

        # Validate configuration
        self.config.validate()

        # Initialize format converter
        self.format_converter = FormatConverter()

        # Set API key
        self.set_api_key(api_key)

    def _create_jsonl_writer(self):
        """
        Create a file-like object that captures JSONL output.

        Returns:
            A file-like object that captures writes
        """
        class JsonlCapture:
            def __init__(self):
                self.data = []

            def write(self, content):
                try:
                    # Try to parse as JSON
                    json_obj = json.loads(content)
                    self.data.append(json_obj)
                except json.JSONDecodeError:
                    # If not valid JSON, just append as string
                    self.data.append(content)

            def flush(self):
                pass

            def close(self):
                pass

        return JsonlCapture()

    def _prepare_file_object(
        self,
        audio_file: Union[Path, BytesIO],
        options: Optional[Dict[str, Any]] = None
    ) -> tuple[BinaryIO, bool]:
        """
        Prepare file object for API call. PATCH: file-like objects require 'file_extension' in options.

        Args:
            audio_file: Path to audio file or file-like object
            options: Dict containing 'file_extension' if audio_file is file-like

        Returns:
            Tuple of (file_object, should_close_file)

        Raises:
            ValueError: If file-like object is provided without 'file_extension' in options
        """
        if isinstance(audio_file, Path):
            try:
                file_obj = open(audio_file, "rb")
                should_close = True
            except (IOError, OSError) as e:
                raise RuntimeError(f"Failed to open audio file '{audio_file}': {e}") from e
        else:
            file_extension = options.get("file_extension", None) if options else None
            if not file_extension:
                logger.error(f"No file extension provided in options for file-like object: {audio_file}")
                raise ValueError(
                    "For file-like objects, options['file_extension'] "
                    "must be provided (PATCH: the OpenAI API requires a filename)."
                )
            file_obj = patch_file_with_name(audio_file, file_extension)
            should_close = False

        return file_obj, should_close

    def _prepare_api_params(
        self, options: Optional[Dict[str, Any]] = None
        ) -> Dict[str, Any]:
        """
        Prepare parameters for the Whisper API call.

        Args:
            options: Additional options for this specific transcription

        Returns:
            Dictionary of parameters for the API call
        """
        base_params = self.config.to_dict()

        # Defensive: ensure options is a dict
        options = options or {}

        # Determine which response format we're using
        response_format = options.get(
            "response_format", self.config.response_format
        )

        # Compute allowed parameters for the chosen format
        allowed_params = (
            set(self.config.FORMAT_PARAMS.get(response_format, []))
            | 
            set(self.config.BASE_PARAMS)
        )

        # Start with base params, filtered to allowed
        api_params = {k: v for k, v in base_params.items() if k in allowed_params}

        # Overlay with options, but only allowed keys
        for k, v in options.items():
            if k in allowed_params:
                api_params[k] = v

        # Validate response format
        if api_params.get("response_format") not in self.config.SUPPORTED_FORMATS:
            logger.warning(
                f"Unsupported response format: {api_params.get('response_format')}, "
                f"defaulting to 'verbose_json'"
            )
            api_params["response_format"] = "verbose_json"

        return api_params

    def _to_whisper_response(self, response: Any) -> WhisperResponse:
        """
        Convert an OpenAI Whisper API response (JSON or Verbose JSON) into a clean,
        type-safe WhisperResponse structure.

        Args:
            response: API response (should have .model_dump())

        Returns:
            A WhisperResponse dictionary
        """

        if hasattr(response, "model_dump"):
            data = response.model_dump(exclude_unset=True)
        elif hasattr(response, "to_dict"):
            data = response.to_dict()
        elif isinstance(response, dict):
            data = response
        elif isinstance(response, str):
            data = {"text": response} # mimic minimal data response format.
        else:
            raise ValueError(f"OpenAI response does not have a method to extract data "
                             f"(missing 'model_dump' or 'to_dict'): {repr(response)}")

        # Required field: duration
        duration = float(data.get("duration", 0.0))

        # Required field: text 
        text = data.get("text")
        if not isinstance(text, str):
            raise ValueError(f"Invalid response: 'text' must be a string, "
                             f"got {type(text)}")

        # Optional fields with normalization 
        language = data.get("language") or self.config.language or "unknown"
        if not isinstance(language, str):
            raise ValueError(f"Unexpected OpenAI response: 'language' is not a string, "
                             f"got {type(language)}")

        # Optional: words and segments (only present in verbose_json)
        words = data.get("words")
        if words is not None and not isinstance(words, list):
            raise ValueError(f"Invalid 'words': expected list, got {type(words)}")

        segments = data.get("segments")
        if segments is not None and not isinstance(segments, list):
            raise ValueError(f"Invalid 'segments': expected list, got {type(segments)}")

        return WhisperResponse(
            text=text,
            language=language,
            duration=duration,
            words=words,
            segments=segments,
        )

    def set_api_key(self, api_key: Optional[str] = None) -> None:
        """
        Set or update the API key.

        This method allows refreshing the API key without re-instantiating the class.

        Args:
            api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)

        Raises:
            ValueError: If no API key is provided or found in environment
        """
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")

        if not self.api_key:
            raise ValueError(
                "OpenAI API key is required. Set OPENAI_API_KEY environment "
                "variable or pass as api_key parameter."
            )

        # Configure OpenAI client
        openai.api_key = self.api_key
        logger.debug("API key updated")

    def _seconds_to_ms(self, seconds: Optional[float]) -> Optional[int]:
        """
        Convert seconds to milliseconds.

        Args:
            seconds: Time in seconds

        Returns:
            Time in milliseconds or None if seconds is None
        """
        return None if seconds is None else int(seconds * 1000)

    def _export_response(self, response: WhisperResponse) -> TranscriptionResult:
        """Process and validate WhisperResponse into TranscriptionResult."""       
        return TranscriptionResult(
            text=response["text"],
            language=response["language"],
            word_timing=self._extract_and_validate_words(response),
            utterance_timing=self._extract_and_validate_utterances(response),
            confidence=0.0,  # Whisper doesn't provide overall confidence
            audio_duration_ms=self._seconds_to_ms(response.get("duration")),
            transcript_id=None,  # No ID from Whisper
            status="completed",  # Whisper is synchronous, so status is always "completed"
            raw_result=dict(response),  # Store the original response for debugging
        )

    def _extract_and_validate_words(
        self, response: WhisperResponse
        ) -> TimedText:
        """Extract, validate, and convert word data into WordTiming objects."""
        words_data = response.get("words")
        units: list[TimedTextUnit] = []

        if words_data:
            for i, word_entry in enumerate(words_data, start=1):
                word = word_entry.get("word")
                start_ms = self._seconds_to_ms(word_entry.get("start"))
                end_ms = self._seconds_to_ms(word_entry.get("end"))

                if not isinstance(word, str) or not word:
                    logger.warning(f"Invalid or missing word: {word_entry}")
                    continue

                if not isinstance(start_ms, int) or not isinstance(end_ms, int):
                    logger.warning(f"Invalid timestamps for word: {word_entry}")
                    continue

                if start_ms > end_ms:
                    logger.warning(
                        f"Invalid timestamps: start ({start_ms}) > end ({end_ms}) "
                        f"for word: {word}. Setting end = start + 1."
                    )
                    end_ms = start_ms + 1

                if start_ms == end_ms:
                    # Workaround for OpenAI Whisper API bug:
                    # Sometimes start == end for word timestamps, which is invalid for downstream consumers.
                    logger.debug(
                        f"Whisper API returned identical start and end times "
                        f"({start_ms} ms) for word '{word}'. "
                        "Adjusting end_ms to start_ms + 1."
                    )
                    end_ms += 1

                units.append(
                    TimedTextUnit(
                        index=i,
                        text=word,
                        start_ms=start_ms,
                        end_ms=end_ms,
                        speaker=None,
                        granularity=Granularity.WORD,
                        confidence=0.0,
                    )
                )

        return TimedText(words=units, granularity=Granularity.WORD)

    def _extract_and_validate_utterances(
        self, response: WhisperResponse
        ) -> TimedText:
        """Extract and validate utterance segments into Utterance objects."""
        segments = response.get("segments")
        units: list[TimedTextUnit] = []

        if segments:
            for i, segment in enumerate(segments, start=1):
                start_ms = self._seconds_to_ms(segment.get("start"))
                end_ms = self._seconds_to_ms(segment.get("end"))
                text = segment.get("text", "")

                if not isinstance(start_ms, int) or not isinstance(end_ms, int):
                    logger.warning(f"Invalid segment timestamps: {segment}")
                    continue

                if not isinstance(text, str) or not text.strip():
                    logger.warning(f"Empty or invalid text for segment: {segment}")
                    continue

                units.append(
                    TimedTextUnit(
                        index=i,
                        text=text,
                        start_ms=start_ms,
                        end_ms=end_ms,
                        speaker=None,
                        granularity=Granularity.SEGMENT,
                        confidence=_logprob_to_confidence(segment.get("avg_logprob", 0.0)),
                    )
                )

        return TimedText(segments=units, granularity=Granularity.SEGMENT)

    def transcribe(
        self,
        audio_file: Union[Path, BytesIO],
        options: Optional[Dict[str, Any]] = None,
    ) -> TranscriptionResult:
        """
        Transcribe audio file to text using OpenAI Whisper API.

        PATCH: If audio_file is a file-like object, options['file_extension'] must be provided 
        (OpenAI API quirk).

        Args:
            audio_file: Path to audio file or file-like object
            options: Provider-specific options for transcription. 
                     If audio_file is file-like, must include 'file_extension'.

        Returns:
            TranscriptionResult containing transcription results with standardized fields

        Raises:
            ValueError: If file-like object is provided without 'file_extension' in options
        """
        # Prepare file object
        file_obj, should_close = self._prepare_file_object(audio_file, options)
        try:
            return self._transcribe_execute(options, file_obj)
        except Exception as e:
            logger.error(f"Error during transcription: {e}")
            raise
        finally:
            # Clean up file object if we opened it
            if should_close:
                file_obj.close()

    def _transcribe_execute(self, options, file_obj):
        # Prepare API parameters
        api_params = self._prepare_api_params(options)
        api_params["file"] = file_obj

        # Call OpenAI API
        logger.info(f"Transcribing audio with Whisper API "
                    f"using model: {api_params['model']}")
        raw_response = openai.audio.transcriptions.create(**api_params)
        response = self._to_whisper_response(raw_response)

        result = self._export_response(response)

        logger.info("Transcription completed successfully")
        return result

    def get_result(self, job_id: str) -> TranscriptionResult:
        """
        Get results for an existing transcription job.

        Whisper API operates synchronously and doesn't use job IDs,
        so this method is not implemented.

        Args:
            job_id: ID of the transcription job

        Returns:
            TranscriptionResult containing transcription results

        Raises:
            NotImplementedError: This method is not supported for Whisper
        """
        raise NotImplementedError(
            "Whisper API operates synchronously.\n"
            "Does not support retrieving results by job ID.\n"
            "Use the transcribe() method for direct transcription."
        )

    def transcribe_to_format(
        self,
        audio_file: Union[Path, BytesIO],
        format_type: str = "srt",
        transcription_options: Optional[Dict[str, Any]] = None,
        format_options: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Transcribe audio and return result in specified format.

        PATCH: If audio_file is a file-like object, transcription_options['file_extension'] must be provided 
        (OpenAI API quirk).

        Takes advantage of the direct subtitle generation functionality when requesting SRT or VTT formats.

        Args:
            audio_file: Path, file-like object, or URL of audio file
            format_type: Format type (e.g., "srt", "vtt", "text")
            transcription_options: Options for transcription. If audio_file is file-like, must include 
                                   'file_extension'.
            format_options: Format-specific options

        Returns:
            String representation in the requested format

        Raises:
            ValueError: If file-like object is provided without 'file_extension' in transcription_options
        """
        format_type = format_type.lower()

        # If requesting SRT or VTT directly, use native OpenAI capabilities
        if format_type in {"srt", "vtt"}:
            # Create options with format set to SRT or VTT
            options = transcription_options.copy() if transcription_options else {}
            options["response_format"] = format_type

            # Prepare file object
            file_obj, should_close = self._prepare_file_object(audio_file, options)

            try:
                # Prepare API parameters
                api_params = self._prepare_api_params(options)
                api_params["file"] = file_obj

                # Call OpenAI API
                logger.info(f"Transcribing directly to {format_type} with Whisper API")
                return openai.audio.transcriptions.create(**api_params)
            finally:
                # Clean up file object if we opened it
                if should_close:
                    file_obj.close()

        # For other formats, use the format converter
        # First get a normal transcription result
        result = self.transcribe(audio_file, transcription_options)

        # Then convert to the requested format
        return self.format_converter.convert(
            result, format_type, format_options or {}
        )
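The segment filtering in `_export_response` above (drop segments with missing timestamps or blank text, convert seconds to milliseconds) can be sketched as a standalone helper. This is an illustrative function, not part of the package's API:

```python
from typing import Any, Dict, List

def filter_valid_segments(segments: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep only segments with numeric timestamps and non-empty text.

    Hypothetical mirror of the validation loop in _export_response.
    """
    valid = []
    for seg in segments:
        start, end = seg.get("start"), seg.get("end")
        text = seg.get("text", "")
        # Timestamps must be numeric seconds; skip otherwise.
        if not isinstance(start, (int, float)) or not isinstance(end, (int, float)):
            continue
        # Text must be a non-blank string.
        if not isinstance(text, str) or not text.strip():
            continue
        valid.append({"start_ms": int(start * 1000), "end_ms": int(end * 1000), "text": text})
    return valid
```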
config = WhisperConfig() instance-attribute
format_converter = FormatConverter() instance-attribute
__init__(api_key=None, **config_options)

Initialize the Whisper transcription service.

Parameters:

Name Type Description Default
api_key Optional[str]

OpenAI API key (defaults to OPENAI_API_KEY env var)

None
**config_options

Additional configuration options

{}
Source code in src/tnh_scholar/audio_processing/transcription/whisper_service.py
def __init__(self, api_key: Optional[str] = None, **config_options):
    """
    Initialize the Whisper transcription service.

    Args:
        api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
        **config_options: Additional configuration options
    """
    # Create configuration base
    self.config = WhisperConfig()

    # Set any configuration options provided
    for key, value in config_options.items():
        if hasattr(self.config, key):
            setattr(self.config, key, value)

    # Validate configuration
    self.config.validate()

    # Initialize format converter
    self.format_converter = FormatConverter()

    # Set API key
    self.set_api_key(api_key)
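The `**config_options` handling above sets only attributes the config object already defines, silently ignoring unknown keys. A minimal sketch of that pattern, using a hypothetical stand-in config class:

```python
class Cfg:
    """Stand-in config with a couple of known fields (hypothetical)."""
    model = "whisper-1"
    language = None

def apply_options(config, **options):
    """Set only attributes the config already defines, as __init__ does above."""
    for key, value in options.items():
        if hasattr(config, key):
            setattr(config, key, value)
    return config
```

Note that unknown keys are dropped rather than raising; validation is deferred to `config.validate()`.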
get_result(job_id)

Get results for an existing transcription job.

Whisper API operates synchronously and doesn't use job IDs, so this method is not implemented.

Parameters:

Name Type Description Default
job_id str

ID of the transcription job

required

Returns:

Type Description
TranscriptionResult

Transcription results for the job

Raises:

Type Description
NotImplementedError

This method is not supported for Whisper

Source code in src/tnh_scholar/audio_processing/transcription/whisper_service.py
def get_result(self, job_id: str) -> TranscriptionResult:
    """
    Get results for an existing transcription job.

    Whisper API operates synchronously and doesn't use job IDs,
    so this method is not implemented.

    Args:
        job_id: ID of the transcription job

    Returns:
        TranscriptionResult containing transcription results

    Raises:
        NotImplementedError: This method is not supported for Whisper
    """
    raise NotImplementedError(
        "Whisper API operates synchronously.\n"
        "Does not support retrieving results by job ID.\n"
        "Use the transcribe() method for direct transcription."
    )
set_api_key(api_key=None)

Set or update the API key.

This method allows refreshing the API key without re-instantiating the class.

Parameters:

Name Type Description Default
api_key Optional[str]

OpenAI API key (defaults to OPENAI_API_KEY env var)

None

Raises:

Type Description
ValueError

If no API key is provided or found in environment

Source code in src/tnh_scholar/audio_processing/transcription/whisper_service.py
def set_api_key(self, api_key: Optional[str] = None) -> None:
    """
    Set or update the API key.

    This method allows refreshing the API key without re-instantiating the class.

    Args:
        api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)

    Raises:
        ValueError: If no API key is provided or found in environment
    """
    self.api_key = api_key or os.getenv("OPENAI_API_KEY")

    if not self.api_key:
        raise ValueError(
            "OpenAI API key is required. Set OPENAI_API_KEY environment "
            "variable or pass as api_key parameter."
        )

    # Configure OpenAI client
    openai.api_key = self.api_key
    logger.debug("API key updated")
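The key-resolution logic in `set_api_key` (explicit argument first, then the `OPENAI_API_KEY` environment variable, else raise) can be sketched as a free function. `resolve_api_key` is a hypothetical name used for illustration:

```python
import os
from typing import Optional

def resolve_api_key(api_key: Optional[str] = None) -> str:
    """Return an explicit key, else fall back to the OPENAI_API_KEY env var."""
    key = api_key or os.getenv("OPENAI_API_KEY")
    if not key:
        raise ValueError(
            "OpenAI API key is required. Set OPENAI_API_KEY environment "
            "variable or pass as api_key parameter."
        )
    return key
```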
transcribe(audio_file, options=None)

Transcribe audio file to text using OpenAI Whisper API.

PATCH: If audio_file is a file-like object, options['file_extension'] must be provided (OpenAI API quirk).

Parameters:

Name Type Description Default
audio_file Union[Path, BytesIO]

Path to audio file or file-like object

required
options Optional[Dict[str, Any]]

Provider-specific options for transcription. If audio_file is file-like, must include 'file_extension'.

None

Returns:

Type Description
TranscriptionResult

Transcription results with standardized keys

Raises:

Type Description
ValueError

If file-like object is provided without 'file_extension' in options

Source code in src/tnh_scholar/audio_processing/transcription/whisper_service.py
def transcribe(
    self,
    audio_file: Union[Path, BytesIO],
    options: Optional[Dict[str, Any]] = None,
) -> TranscriptionResult:
    """
    Transcribe audio file to text using OpenAI Whisper API.

    PATCH: If audio_file is a file-like object, options['file_extension'] must be provided 
    (OpenAI API quirk).

    Args:
        audio_file: Path to audio file or file-like object
        options: Provider-specific options for transcription. 
                 If audio_file is file-like, must include 'file_extension'.

    Returns:
        TranscriptionResult with standardized keys

    Raises:
        ValueError: If file-like object is provided without 'file_extension' in options
    """
    # Prepare file object
    file_obj, should_close = self._prepare_file_object(audio_file, options)
    try:
        return self._transcribe_execute(options, file_obj)
    except Exception as e:
        logger.error(f"Error during transcription: {e}")
        raise
    finally:
        # Clean up file object if we opened it
        if should_close:
            file_obj.close()
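The documented `ValueError` arises because a raw `BytesIO` carries no filename, so the API cannot infer the audio format. A sketch of that check, assuming a hypothetical helper `upload_name_for` (not the package's `_prepare_file_object`):

```python
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Optional, Union

def upload_name_for(audio_file: Union[Path, BytesIO],
                    options: Optional[Dict[str, Any]] = None) -> str:
    """Derive a filename the API can infer a format from.

    Paths carry their own extension; file-like objects must supply one
    via options['file_extension'] (the ValueError documented above).
    """
    if isinstance(audio_file, Path):
        return audio_file.name
    ext = (options or {}).get("file_extension")
    if not ext:
        raise ValueError("file-like audio requires options['file_extension']")
    return f"audio.{ext.lstrip('.')}"
```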
transcribe_to_format(audio_file, format_type='srt', transcription_options=None, format_options=None)

Transcribe audio and return result in specified format.

PATCH: If audio_file is a file-like object, transcription_options['file_extension'] must be provided (OpenAI API quirk).

Takes advantage of the direct subtitle generation functionality when requesting SRT or VTT formats.

Parameters:

Name Type Description Default
audio_file Union[Path, BytesIO]

Path, file-like object, or URL of audio file

required
format_type str

Format type (e.g., "srt", "vtt", "text")

'srt'
transcription_options Optional[Dict[str, Any]]

Options for transcription. If audio_file is file-like, must include 'file_extension'.

None
format_options Optional[Dict[str, Any]]

Format-specific options

None

Returns:

Type Description
str

String representation in the requested format

Raises:

Type Description
ValueError

If file-like object is provided without 'file_extension' in transcription_options

Source code in src/tnh_scholar/audio_processing/transcription/whisper_service.py
def transcribe_to_format(
    self,
    audio_file: Union[Path, BytesIO],
    format_type: str = "srt",
    transcription_options: Optional[Dict[str, Any]] = None,
    format_options: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Transcribe audio and return result in specified format.

    PATCH: If audio_file is a file-like object, transcription_options['file_extension'] must be provided 
    (OpenAI API quirk).

    Takes advantage of the direct subtitle generation functionality when requesting SRT or VTT formats.

    Args:
        audio_file: Path, file-like object, or URL of audio file
        format_type: Format type (e.g., "srt", "vtt", "text")
        transcription_options: Options for transcription. If audio_file is file-like, must include 
                               'file_extension'.
        format_options: Format-specific options

    Returns:
        String representation in the requested format

    Raises:
        ValueError: If file-like object is provided without 'file_extension' in transcription_options
    """
    format_type = format_type.lower()

    # If requesting SRT or VTT directly, use native OpenAI capabilities
    if format_type in {"srt", "vtt"}:
        # Create options with format set to SRT or VTT
        options = transcription_options.copy() if transcription_options else {}
        options["response_format"] = format_type

        # Prepare file object
        file_obj, should_close = self._prepare_file_object(audio_file, options)

        try:
            # Prepare API parameters
            api_params = self._prepare_api_params(options)
            api_params["file"] = file_obj

            # Call OpenAI API
            logger.info(f"Transcribing directly to {format_type} with Whisper API")
            return openai.audio.transcriptions.create(**api_params)
        finally:
            # Clean up file object if we opened it
            if should_close:
                file_obj.close()

    # For other formats, use the format converter
    # First get a normal transcription result
    result = self.transcribe(audio_file, transcription_options)

    # Then convert to the requested format
    return self.format_converter.convert(
        result, format_type, format_options or {}
    )
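The control flow above has two paths: SRT/VTT are requested natively via `response_format`, while every other format goes through a plain transcription followed by the format converter. A self-contained sketch of that dispatch with stub callables (illustrative only):

```python
from typing import Any, Callable, Dict, Optional

NATIVE_FORMATS = {"srt", "vtt"}  # formats the Whisper API can emit directly

def transcribe_to_format_sketch(
    format_type: str,
    native_transcribe: Callable[[Dict[str, Any]], str],
    convert: Callable[[str, str], str],
    transcription_options: Optional[Dict[str, Any]] = None,
) -> str:
    """Dispatch between the native path and the converter path (stubs)."""
    fmt = format_type.lower()
    if fmt in NATIVE_FORMATS:
        options = dict(transcription_options or {})
        options["response_format"] = fmt  # ask the API for srt/vtt directly
        return native_transcribe(options)
    # Otherwise: plain transcription, then post-hoc format conversion.
    return convert(native_transcribe(dict(transcription_options or {})), fmt)
```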
WordEntry

Bases: TypedDict

Source code in src/tnh_scholar/audio_processing/transcription/whisper_service.py
class WordEntry(TypedDict, total=False):
    word: str
    start: Optional[float]
    end: Optional[float]
end instance-attribute
start instance-attribute
word instance-attribute

utils

__all__ = ['AudioEnhancer', 'get_segment_audio', 'play_audio_segment', 'play_bytes', 'play_from_file', 'play_diarization_segment', 'get_audio_from_file'] module-attribute
AudioEnhancer
Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
class AudioEnhancer:
    def __init__(
        self, 
        config: EnhancementConfig = EnhancementConfig(), 
        compression_settings: CompressionSettings = CompressionSettings()
        ):
        """Initialize with enhancement configuration and compression settings."""

        # Check required tools
        for tool in ["sox", "ffmpeg"]:
            try:
                subprocess.run(["which", tool], capture_output=True, text=True, check=True)
            except (subprocess.SubprocessError, FileNotFoundError) as e:
                raise RuntimeError(f"{tool} is not installed. Please install it first.") from e

        self.config = config
        self.compression_settings = compression_settings

    def enhance(self, input_path: Path, output_path: Optional[Path] = None) -> Path:
        """
        Apply enhancement routines (compression, EQ, gating, etc.) in a modular fashion.
        Converts input to FLAC working format for Whisper compatibility.
        """
        input_path = Path(input_path)
        if output_path is None:
            output_path = input_path.parent / f"{input_path.stem}_enhanced.flac"

        # Step 1: Convert to FLAC if needed
        working_flac = input_path.with_suffix(".flac")
        if not working_flac.exists():
            self._convert_to_flac(input_path, working_flac)

        # Step 2: Build SoX command modularly using helper methods
        sox_cmd = ["sox", str(working_flac), str(output_path)]
        sox_cmd.extend(self._set_remix())
        sox_cmd.extend(self._set_rate())
        sox_cmd.extend(self._set_gain())
        sox_cmd.extend(self._set_freq())
        sox_cmd.extend(self._set_eq())
        sox_cmd.extend(self._set_compand())
        sox_cmd.extend(self._set_gate())
        sox_cmd.extend(self._set_contrast_bass_treble())
        sox_cmd.extend(self._set_norm())

        result = subprocess.run(sox_cmd, capture_output=True, text=True)
        if result.returncode != 0:
            logger.error(f"SoX Error: {result.stderr}")
            raise RuntimeError(f"SoX processing failed: {result.stderr}")
        return output_path

    def _set_remix(self) -> list[str]:
        """Set remix channels if force_mono is enabled."""
        if self.config.force_mono:
            return ["remix", self.config.remix.remix_channels]
        return []

    def _set_rate(self) -> list[str]:
        """Set sample rate if target_rate is specified."""
        if self.config.target_rate:
            return ["rate", *self.config.rate.rate_args, str(self.config.target_rate)]
        return []

    def _set_gain(self) -> list[str]:
        """Set gain normalization."""
        return ["gain", "-n", str(self.config.norm.norm_level)]

    def _set_freq(self) -> list[str]:
        """Set highpass and lowpass frequencies."""
        return [
            "highpass", "-1", str(self.config.eq.highpass_freq),
            "lowpass", "-1", str(self.config.eq.lowpass_freq)
        ]

    def _set_eq(self) -> list[str]:
        """Set equalizer bands."""
        eq_cmd = []
        for freq, width, gain in self.config.eq.eq_bands:
            eq_cmd.extend(["equalizer", str(freq), str(width), str(gain)])
        return eq_cmd

    def _set_compand(self) -> list[str]:
        """Set compression arguments."""
        comp_args: list[str] = getattr(
            self.compression_settings,
            self.config.compression_level,
            self.compression_settings.whisper_optimized
        )
        return ["compand", *comp_args, ":"]

    def _set_gate(self) -> list[str]:
        """Set gate if enabled."""
        if self.config.include_gate:
            return ["gate", *self.config.gate.gate_params]
        return []

    def _set_contrast_bass_treble(self) -> list[str]:
        """Set contrast, bass, and treble if EQ is enabled."""
        if self.config.include_eq:
            return [
                "contrast", str(self.config.eq.contrast),
                "bass", str(self.config.eq.bass[0]), str(self.config.eq.bass[1]),
                "treble", f"+{self.config.eq.treble[0]}", str(self.config.eq.treble[1])
            ]
        return []

    def _set_norm(self) -> list[str]:
        """Set normalization."""
        return ["norm", str(self.config.norm.norm_level)]

    def _convert_to_flac(self, input_path: Path, output_path: Path) -> None:
        """
        Convert input audio to FLAC format using ffmpeg, preserving maximal fidelity.
        """
        cmd = [
            "ffmpeg", "-i", str(input_path),
            "-map", "0:a:0",
            "-c:a", "flac",
            "-compression_level", "8",
            str(output_path),
            "-y"
        ]
        # Note: check=True would raise CalledProcessError before the branch below
        # ever ran, so rely on the explicit returncode check instead.
        result = subprocess.run(cmd, capture_output=True)
        if result.returncode != 0:
            logger.error(f"FFmpeg Error: {result.stderr.decode()}")
            raise RuntimeError(f"FFmpeg conversion failed: {result.stderr.decode()}")

    def extract_sample(
        self,
        input_path: Path,
        start: float,
        duration: float,
        output_path: Optional[Path] = None,
        output_format: str = "flac",
        codec: Optional[str] = None,
        compression_level: int = 8,
    ) -> Path:
        """
        Extract a sample segment from the audio file.

        Parameters
        ----------
        input_path : Path
            Path to the input audio file.
        start : float
            Start time in seconds.
        duration : float
            Duration in seconds.
        output_path : Path, optional
            Output file path. If None, auto-generated from input.
        output_format : str, default="flac"
            Output audio format/extension.
        codec : str, optional
            Audio codec to use (default: "flac" if output_format is "flac", else None).
        compression_level : int, default=8
            Compression level for supported codecs.

        Returns
        -------
        Path
            Path to the extracted audio sample.
        """
        input_path = Path(input_path)
        output_path = self._sample_output_path(input_path, output_path, start, duration, output_format)

        if codec is None:
            codec = "flac" if output_format == "flac" else None

        cmd = [
            "ffmpeg", "-y",
            "-ss", str(start),
            "-t", str(duration),
            "-i", str(input_path),
        ]
        if codec:
            cmd += ["-c:a", codec]
        if codec == "flac":
            cmd += ["-compression_level", str(compression_level)]
        cmd.append(str(output_path))

        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            logger.error(f"FFmpeg sample extraction failed: {result.stderr}")
            raise RuntimeError(f"Sample extraction failed: {result.stderr}")
        return output_path

    def _sample_output_path(self, input_path, output_path, start, duration, output_format) -> Path:
        if output_path is None:
            return ( 
                    input_path.parent / 
                    f"{input_path.stem}_sample_{int(start)}s_{int(duration)}s.{output_format}"
            )
        return Path(output_path)

    def play_audio(self, file_path: Path):
        """Play audio in notebook for quality assessment."""
        display(Audio(str(file_path)))

    def get_audio_info(self, file_path: Path):
        """Get detailed audio information using ffprobe."""
        cmd = [
            "ffprobe", "-v", "quiet", "-print_format", "json",
            "-show_streams", "-select_streams", "a:0", str(file_path)
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode == 0:
            return self._display_stream_info(result, file_path)
        logger.error(f"FFprobe error: {result.stderr}")
        raise RuntimeError("Failed to retrieve audio info.")

    def _display_stream_info(self, result: subprocess.CompletedProcess, file_path: Path) -> dict:
        data = json.loads(result.stdout)
        stream = data["streams"][0]

        logger.info(f"File: {file_path}")
        logger.info(f"Codec: {stream.get('codec_name', 'Unknown')}")
        logger.info(f"Sample Rate: {stream.get('sample_rate', 'Unknown')} Hz")
        logger.info(f"Channels: {stream.get('channels', 'Unknown')}")
        logger.info(f"Bit Rate: {stream.get('bit_rate', 'Unknown')} bps")
        logger.info(f"Duration: {stream.get('duration', 'Unknown')} seconds")
        logger.info(f"Sample Format: {stream.get('sample_fmt', 'Unknown')}")

        return stream
compression_settings = compression_settings instance-attribute
config = config instance-attribute
__init__(config=EnhancementConfig(), compression_settings=CompressionSettings())

Initialize with enhancement configuration and compression settings.

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
def __init__(
    self, 
    config: EnhancementConfig = EnhancementConfig(), 
    compression_settings: CompressionSettings = CompressionSettings()
    ):
    """Initialize with enhancement configuration and compression settings."""

    # Check required tools
    for tool in ["sox", "ffmpeg"]:
        try:
            subprocess.run(["which", tool], capture_output=True, text=True, check=True)
        except (subprocess.SubprocessError, FileNotFoundError) as e:
            raise RuntimeError(f"{tool} is not installed. Please install it first.") from e

    self.config = config
    self.compression_settings = compression_settings
enhance(input_path, output_path=None)

Apply enhancement routines (compression, EQ, gating, etc.) in a modular fashion. Converts input to FLAC working format for Whisper compatibility.

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
def enhance(self, input_path: Path, output_path: Optional[Path] = None) -> Path:
    """
    Apply enhancement routines (compression, EQ, gating, etc.) in a modular fashion.
    Converts input to FLAC working format for Whisper compatibility.
    """
    input_path = Path(input_path)
    if output_path is None:
        output_path = input_path.parent / f"{input_path.stem}_enhanced.flac"

    # Step 1: Convert to FLAC if needed
    working_flac = input_path.with_suffix(".flac")
    if not working_flac.exists():
        self._convert_to_flac(input_path, working_flac)

    # Step 2: Build SoX command modularly using helper methods
    sox_cmd = ["sox", str(working_flac), str(output_path)]
    sox_cmd.extend(self._set_remix())
    sox_cmd.extend(self._set_rate())
    sox_cmd.extend(self._set_gain())
    sox_cmd.extend(self._set_freq())
    sox_cmd.extend(self._set_eq())
    sox_cmd.extend(self._set_compand())
    sox_cmd.extend(self._set_gate())
    sox_cmd.extend(self._set_contrast_bass_treble())
    sox_cmd.extend(self._set_norm())

    result = subprocess.run(sox_cmd, capture_output=True, text=True)
    if result.returncode != 0:
        logger.error(f"SoX Error: {result.stderr}")
        raise RuntimeError(f"SoX processing failed: {result.stderr}")
    return output_path
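Each `_set_*` helper above returns a (possibly empty) argument list, so disabled steps simply vanish when extended onto the command. A minimal self-contained sketch of that assembly pattern, with two hypothetical steps standing in for the full chain:

```python
from typing import List, Optional

def remix_args(force_mono: bool) -> List[str]:
    """Empty list when disabled, so the step drops out of the command."""
    return ["remix", "-"] if force_mono else []

def rate_args(target_rate: Optional[int]) -> List[str]:
    return ["rate", str(target_rate)] if target_rate else []

def build_sox_cmd(src: str, dst: str,
                  force_mono: bool, target_rate: Optional[int]) -> List[str]:
    cmd = ["sox", src, dst]
    cmd.extend(remix_args(force_mono))
    cmd.extend(rate_args(target_rate))
    return cmd
```

This is why `enhance` can call every `_set_*` method unconditionally: a disabled step contributes nothing.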
extract_sample(input_path, start, duration, output_path=None, output_format='flac', codec=None, compression_level=8)

Extract a sample segment from the audio file.

Parameters

input_path : Path
    Path to the input audio file.
start : float
    Start time in seconds.
duration : float
    Duration in seconds.
output_path : Path, optional
    Output file path. If None, auto-generated from input.
output_format : str, default="flac"
    Output audio format/extension.
codec : str, optional
    Audio codec to use (default: "flac" if output_format is "flac", else None).
compression_level : int, default=8
    Compression level for supported codecs.

Returns

Path
    Path to the extracted audio sample.

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
def extract_sample(
    self,
    input_path: Path,
    start: float,
    duration: float,
    output_path: Optional[Path] = None,
    output_format: str = "flac",
    codec: Optional[str] = None,
    compression_level: int = 8,
) -> Path:
    """
    Extract a sample segment from the audio file.

    Parameters
    ----------
    input_path : Path
        Path to the input audio file.
    start : float
        Start time in seconds.
    duration : float
        Duration in seconds.
    output_path : Path, optional
        Output file path. If None, auto-generated from input.
    output_format : str, default="flac"
        Output audio format/extension.
    codec : str, optional
        Audio codec to use (default: "flac" if output_format is "flac", else None).
    compression_level : int, default=8
        Compression level for supported codecs.

    Returns
    -------
    Path
        Path to the extracted audio sample.
    """
    input_path = Path(input_path)
    output_path = self._sample_output_path(input_path, output_path, start, duration, output_format)

    if codec is None:
        codec = "flac" if output_format == "flac" else None

    cmd = [
        "ffmpeg", "-y",
        "-ss", str(start),
        "-t", str(duration),
        "-i", str(input_path),
    ]
    if codec:
        cmd += ["-c:a", codec]
    if codec == "flac":
        cmd += ["-compression_level", str(compression_level)]
    cmd.append(str(output_path))

    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        logger.error(f"FFmpeg sample extraction failed: {result.stderr}")
        raise RuntimeError(f"Sample extraction failed: {result.stderr}")
    return output_path
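The ffmpeg invocation above is easiest to verify as a pure command builder, separate from the subprocess call. A sketch of that argument assembly (`sample_cmd` is a hypothetical name for illustration):

```python
from pathlib import Path
from typing import List, Optional

def sample_cmd(input_path: Path, start: float, duration: float,
               output_path: Path, codec: Optional[str] = "flac",
               compression_level: int = 8) -> List[str]:
    """Assemble the ffmpeg argument list used by extract_sample (sketch)."""
    # -ss/-t before -i: seek and limit duration on the input side.
    cmd = ["ffmpeg", "-y", "-ss", str(start), "-t", str(duration), "-i", str(input_path)]
    if codec:
        cmd += ["-c:a", codec]
    if codec == "flac":
        cmd += ["-compression_level", str(compression_level)]
    cmd.append(str(output_path))
    return cmd
```

Separating command construction from execution also enables the "dry run" debugging mode suggested in the module review below.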
get_audio_info(file_path)

Get detailed audio information using ffprobe.

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
def get_audio_info(self, file_path: Path):
    """Get detailed audio information using ffprobe."""
    cmd = [
        "ffprobe", "-v", "quiet", "-print_format", "json",
        "-show_streams", "-select_streams", "a:0", str(file_path)
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode == 0:
        return self._display_stream_info(result, file_path)
    logger.error(f"FFprobe error: {result.stderr}")
    raise RuntimeError("Failed to retrieve audio info.")
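`get_audio_info` parses ffprobe's `-print_format json` output and takes the first audio stream. That parsing step can be exercised against a canned payload; the sample JSON below is made up for illustration:

```python
import json
from typing import Any, Dict

def first_audio_stream(ffprobe_json: str) -> Dict[str, Any]:
    """Pull the first audio stream's fields out of ffprobe JSON output.

    Mirrors the parsing in _display_stream_info above.
    """
    data = json.loads(ffprobe_json)
    return data["streams"][0]
```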
play_audio(file_path)

Play audio in notebook for quality assessment.

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
def play_audio(self, file_path: Path):
    """Play audio in notebook for quality assessment."""
    display(Audio(str(file_path)))
get_audio_from_file(audio_file)
Source code in src/tnh_scholar/audio_processing/utils/playback.py
def get_audio_from_file(audio_file: Path) -> AudioSegment:
    audio_format_ext = audio_file.suffix.lstrip(".").lower()
    return AudioSegment.from_file(audio_file, format=audio_format_ext)
get_segment_audio(segment, audio)
Source code in src/tnh_scholar/audio_processing/utils/playback.py
def get_segment_audio(segment: DiarizedSegment, audio: AudioSegment):
    return audio[segment.start:segment.end]
play_audio_segment(audio)
Source code in src/tnh_scholar/audio_processing/utils/playback.py
def play_audio_segment(audio: AudioSegment):
    play(audio)
play_bytes(data, format='wav')
Source code in src/tnh_scholar/audio_processing/utils/playback.py
def play_bytes(data: BytesIO, format: str = "wav"):
    audio = AudioSegment.from_file(data, format=format)
    play(audio)
play_diarization_segment(segment, audio)
Source code in src/tnh_scholar/audio_processing/utils/playback.py
def play_diarization_segment(segment: DiarizedSegment, audio: AudioSegment):
    play_audio_segment(audio[segment.start:segment.end]) 
play_from_file(path)
Source code in src/tnh_scholar/audio_processing/utils/playback.py
def play_from_file(path: Path):
    audio = AudioSegment.from_file(path)
    play(audio)
audio_enhance

Module review and recommendations:

Big Picture Approach:

- Modular, Configurable, and Extensible: Your use of Pydantic models for settings and configs is excellent. It makes the pipeline flexible and easy to tune for different ASR or enhancement needs.
- Tooling: Leveraging SoX and FFmpeg is a pragmatic choice for robust, high-quality audio processing.
- Pipeline Structure: The AudioEnhancer class is well structured, with clear separation of concerns for each processing step (remix, rate, gain, EQ, compand, etc.).
- Notebook Integration: The play_audio method and use of IPython display is great for interactive, iterative work.

Details & Points You Might Be Missing:

Error Handling & Logging:

- You print errors but could benefit from more structured logging (e.g., using Python's logging module).
- Consider more granular exception handling, especially for subprocess calls.

Testing & Validation:

- No unit tests or validation of output audio quality/format are present. Consider adding automated tests (even if just for file existence, format, and basic properties).
- You could add a method to compare pre/post enhancement SNR, loudness, or other metrics.

Documentation & Examples:

- While docstrings are good, a usage example (in code or markdown) would help new users.
- Consider a README or notebook cell that demonstrates a full workflow.

Performance:

- For large-scale or batch processing, consider parallelization or async processing.
- Temporary files (e.g., intermediate FLACs) could be managed and cleaned up more robustly.

Extensibility:

- The pipeline is modular, but adding a "custom steps" hook (e.g., user-defined SoX/FFmpeg args) would make it even more flexible.
- You might want to support other codecs or output formats for downstream ASR models.

Feature Gaps:

- The extract_sample method is a TODO. Implementing this would be useful for quick QA or dataset creation.
- Consider adding Voice Activity Detection (VAD) or silence trimming as optional steps.
- You could add a "dry run" mode to print the SoX/FFmpeg commands without executing them, for debugging.

ASR-Specific Enhancements:

- You might want to add preset configs for different ASR models (e.g., Whisper, Wav2Vec2, etc.), as they may have different optimal preprocessing.
- Consider integrating with open-source ASR evaluation tools to close the loop on enhancement effectiveness.

General Strategic Recommendations:

- Automate QA: Add methods to check output audio quality, duration, and format, and optionally compare to input.
- Batch Processing: Add a method to process a directory or list of files.
- Config Export/Import: Allow saving/loading configs as JSON/YAML for reproducibility.
- CLI/Script Interface: Consider a command-line interface for use outside notebooks.
- Unit Tests: Add basic tests for each method, especially for error cases.

Summary Table:

| Area | Status | Recommendation |
|---|---|---|
| Modularity | Good | Add custom step hooks |
| Configurability | Excellent | Presets for more ASR models |
| Error Handling | Basic | Use logging, more granular exceptions |
| Testing | Missing | Add unit tests, output validation |
| Documentation | Good | Add usage examples, README |
| Extensibility | Good | Support more codecs, batch processing |
| ASR Optimization | Good start | Add VAD, silence trim, model-specific configs |
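The Config Export/Import recommendation above can be sketched as a simple JSON round trip. This is an illustrative stdlib sketch only: the project's real configs are Pydantic models, where `model_dump_json` / `model_validate_json` would replace the raw `json` calls, and `EnhancementConfigSketch` with its fields is a hypothetical stand-in, not the actual model.

```python
import json
from dataclasses import dataclass, asdict
from pathlib import Path

@dataclass
class EnhancementConfigSketch:
    # Hypothetical subset of EnhancementConfig's fields, for illustration only.
    codec: str = "flac"
    sample_rate: int = 48000
    force_mono: bool = False

def save_config(cfg: EnhancementConfigSketch, path: Path) -> None:
    # Serialize the config to JSON so a run can be reproduced later.
    path.write_text(json.dumps(asdict(cfg), indent=2))

def load_config(path: Path) -> EnhancementConfigSketch:
    # Rebuild the config object from the saved JSON.
    return EnhancementConfigSketch(**json.loads(path.read_text()))
```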

logger = get_child_logger(__name__) module-attribute
AudioEnhancer
Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
class AudioEnhancer:
    def __init__(
        self, 
        config: EnhancementConfig = EnhancementConfig(), 
        compression_settings: CompressionSettings = CompressionSettings()
        ):
        """Initialize with enhancement configuration and compression settings."""

        # Check required tools
        for tool in ["sox", "ffmpeg"]:
            try:
                subprocess.run(["which", tool], capture_output=True, text=True, check=True)
            except (subprocess.SubprocessError, FileNotFoundError) as e:
                raise RuntimeError(f"{tool} is not installed. Please install it first.") from e

        self.config = config
        self.compression_settings = compression_settings

    def enhance(self, input_path: Path, output_path: Optional[Path] = None) -> Path:
        """
        Apply enhancement routines (compression, EQ, gating, etc.) in a modular fashion.
        Converts input to FLAC working format for Whisper compatibility.
        """
        input_path = Path(input_path)
        if output_path is None:
            output_path = input_path.parent / f"{input_path.stem}_enhanced.flac"

        # Step 1: Convert to FLAC if needed
        working_flac = input_path.with_suffix(".flac")
        if not working_flac.exists():
            self._convert_to_flac(input_path, working_flac)

        # Step 2: Build SoX command modularly using helper methods
        sox_cmd = ["sox", str(working_flac), str(output_path)]
        sox_cmd.extend(self._set_remix())
        sox_cmd.extend(self._set_rate())
        sox_cmd.extend(self._set_gain())
        sox_cmd.extend(self._set_freq())
        sox_cmd.extend(self._set_eq())
        sox_cmd.extend(self._set_compand())
        sox_cmd.extend(self._set_gate())
        sox_cmd.extend(self._set_contrast_bass_treble())
        sox_cmd.extend(self._set_norm())

        result = subprocess.run(sox_cmd, capture_output=True, text=True)
        if result.returncode != 0:
            logger.error(f"SoX Error: {result.stderr}")
            raise RuntimeError(f"SoX processing failed: {result.stderr}")
        return output_path

    def _set_remix(self) -> list[str]:
        """Set remix channels if force_mono is enabled."""
        if self.config.force_mono:
            return ["remix", self.config.remix.remix_channels]
        return []

    def _set_rate(self) -> list[str]:
        """Set sample rate if target_rate is specified."""
        if self.config.target_rate:
            return ["rate", *self.config.rate.rate_args, str(self.config.target_rate)]
        return []

    def _set_gain(self) -> list[str]:
        """Set gain normalization."""
        return ["gain", "-n", str(self.config.norm.norm_level)]

    def _set_freq(self) -> list[str]:
        """Set highpass and lowpass frequencies."""
        return [
            "highpass", "-1", str(self.config.eq.highpass_freq),
            "lowpass", "-1", str(self.config.eq.lowpass_freq)
        ]

    def _set_eq(self) -> list[str]:
        """Set equalizer bands."""
        eq_cmd = []
        for freq, width, gain in self.config.eq.eq_bands:
            eq_cmd.extend(["equalizer", str(freq), str(width), str(gain)])
        return eq_cmd

    def _set_compand(self) -> list[str]:
        """Set compression arguments."""
        comp_args: list[str] = getattr(
            self.compression_settings,
            self.config.compression_level,
            self.compression_settings.whisper_optimized
        )
        return ["compand", *comp_args, ":"]

    def _set_gate(self) -> list[str]:
        """Set gate if enabled."""
        if self.config.include_gate:
            return ["gate", *self.config.gate.gate_params]
        return []

    def _set_contrast_bass_treble(self) -> list[str]:
        """Set contrast, bass, and treble if EQ is enabled."""
        if self.config.include_eq:
            return [
                "contrast", str(self.config.eq.contrast),
                "bass", str(self.config.eq.bass[0]), str(self.config.eq.bass[1]),
                "treble", f"+{self.config.eq.treble[0]}", str(self.config.eq.treble[1])
            ]
        return []

    def _set_norm(self) -> list[str]:
        """Set normalization."""
        return ["norm", str(self.config.norm.norm_level)]

    def _convert_to_flac(self, input_path: Path, output_path: Path) -> None:
        """
        Convert input audio to FLAC format using ffmpeg, preserving maximal fidelity.
        """
        cmd = [
            "ffmpeg", "-i", str(input_path),
            "-map", "0:a:0",
            "-c:a", "flac",
            "-compression_level", "8",
            str(output_path),
            "-y"
        ]
        result = subprocess.run(cmd, capture_output=True)
        if result.returncode != 0:
            print(f"FFmpeg Error: {result.stderr.decode()}")
            raise RuntimeError(f"FFmpeg conversion failed: {result.stderr.decode()}")

    def extract_sample(
        self,
        input_path: Path,
        start: float,
        duration: float,
        output_path: Optional[Path] = None,
        output_format: str = "flac",
        codec: Optional[str] = None,
        compression_level: int = 8,
    ) -> Path:
        """
        Extract a sample segment from the audio file.

        Parameters
        ----------
        input_path : Path
            Path to the input audio file.
        start : float
            Start time in seconds.
        duration : float
            Duration in seconds.
        output_path : Path, optional
            Output file path. If None, auto-generated from input.
        output_format : str, default="flac"
            Output audio format/extension.
        codec : str, optional
            Audio codec to use (default: "flac" if output_format is "flac", else None).
        compression_level : int, default=8
            Compression level for supported codecs.

        Returns
        -------
        Path
            Path to the extracted audio sample.
        """
        input_path = Path(input_path)
        output_path = self._sample_output_path(input_path, output_path, start, duration, output_format)

        if codec is None:
            codec = "flac" if output_format == "flac" else None

        cmd = [
            "ffmpeg", "-y",
            "-ss", str(start),
            "-t", str(duration),
            "-i", str(input_path),
        ]
        if codec:
            cmd += ["-c:a", codec]
        if codec == "flac":
            cmd += ["-compression_level", str(compression_level)]
        cmd.append(str(output_path))

        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            logger.error(f"FFmpeg sample extraction failed: {result.stderr}")
            raise RuntimeError(f"Sample extraction failed: {result.stderr}")
        return output_path

    def _sample_output_path(self, input_path, output_path, start, duration, output_format) -> Path:
        if output_path is None:
            return ( 
                    input_path.parent / 
                    f"{input_path.stem}_sample_{int(start)}s_{int(duration)}s.{output_format}"
            )
        return Path(output_path)

    def play_audio(self, file_path: Path):
        """Play audio in notebook for quality assessment."""
        display(Audio(str(file_path)))

    def get_audio_info(self, file_path: Path):
        """Get detailed audio information using ffprobe."""
        cmd = [
            "ffprobe", "-v", "quiet", "-print_format", "json",
            "-show_streams", "-select_streams", "a:0", str(file_path)
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode == 0:
            return self._display_stream_info(result, file_path)
        logger.error(f"FFprobe error: {result.stderr}")
        raise RuntimeError("Failed to retrieve audio info.")

    def _display_stream_info(self, result: subprocess.CompletedProcess, file_path: Path) -> dict:
        data = json.loads(result.stdout)
        stream = data["streams"][0]

        logger.info(f"File: {file_path}")
        logger.info(f"Codec: {stream.get('codec_name', 'Unknown')}")
        logger.info(f"Sample Rate: {stream.get('sample_rate', 'Unknown')} Hz")
        logger.info(f"Channels: {stream.get('channels', 'Unknown')}")
        logger.info(f"Bit Rate: {stream.get('bit_rate', 'Unknown')} bps")
        logger.info(f"Duration: {stream.get('duration', 'Unknown')} seconds")
        logger.info(f"Sample Format: {stream.get('sample_fmt', 'Unknown')}")

        return stream
compression_settings = compression_settings instance-attribute
config = config instance-attribute
__init__(config=EnhancementConfig(), compression_settings=CompressionSettings())

Initialize with enhancement configuration and compression settings.

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
def __init__(
    self, 
    config: EnhancementConfig = EnhancementConfig(), 
    compression_settings: CompressionSettings = CompressionSettings()
    ):
    """Initialize with enhancement configuration and compression settings."""

    # Check required tools
    for tool in ["sox", "ffmpeg"]:
        try:
            subprocess.run(["which", tool], capture_output=True, text=True, check=True)
        except (subprocess.SubprocessError, FileNotFoundError) as e:
            raise RuntimeError(f"{tool} is not installed. Please install it first.") from e

    self.config = config
    self.compression_settings = compression_settings
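The tool check in `__init__` shells out to `which`, which is Unix-only. A portable alternative is `shutil.which`, which avoids spawning a subprocess and also works on Windows. This is a sketch under that assumption, not the project's actual implementation, and `check_required_tools` is an invented name.

```python
import shutil

def check_required_tools(tools: list[str]) -> None:
    """Raise if any external tool is missing from PATH (portable sketch)."""
    missing = [t for t in tools if shutil.which(t) is None]
    if missing:
        raise RuntimeError(f"Missing required tools: {', '.join(missing)}")
```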
enhance(input_path, output_path=None)

Apply enhancement routines (compression, EQ, gating, etc.) in a modular fashion. Converts input to FLAC working format for Whisper compatibility.

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
def enhance(self, input_path: Path, output_path: Optional[Path] = None) -> Path:
    """
    Apply enhancement routines (compression, EQ, gating, etc.) in a modular fashion.
    Converts input to FLAC working format for Whisper compatibility.
    """
    input_path = Path(input_path)
    if output_path is None:
        output_path = input_path.parent / f"{input_path.stem}_enhanced.flac"

    # Step 1: Convert to FLAC if needed
    working_flac = input_path.with_suffix(".flac")
    if not working_flac.exists():
        self._convert_to_flac(input_path, working_flac)

    # Step 2: Build SoX command modularly using helper methods
    sox_cmd = ["sox", str(working_flac), str(output_path)]
    sox_cmd.extend(self._set_remix())
    sox_cmd.extend(self._set_rate())
    sox_cmd.extend(self._set_gain())
    sox_cmd.extend(self._set_freq())
    sox_cmd.extend(self._set_eq())
    sox_cmd.extend(self._set_compand())
    sox_cmd.extend(self._set_gate())
    sox_cmd.extend(self._set_contrast_bass_treble())
    sox_cmd.extend(self._set_norm())

    result = subprocess.run(sox_cmd, capture_output=True, text=True)
    if result.returncode != 0:
        logger.error(f"SoX Error: {result.stderr}")
        raise RuntimeError(f"SoX processing failed: {result.stderr}")
    return output_path
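The command assembly in `enhance()` follows a simple convention: each `_set_*` helper returns a list of SoX argv tokens, and a disabled step returns an empty list that simply vanishes from the final command. A standalone sketch of that pattern (file names and the `build_sox_cmd` helper are hypothetical):

```python
def build_sox_cmd(infile: str, outfile: str, steps: list[list[str]]) -> list[str]:
    # Start with the fixed input/output arguments, then append each
    # effect's tokens; empty lists (disabled steps) contribute nothing.
    cmd = ["sox", infile, outfile]
    for step in steps:
        cmd.extend(step)
    return cmd

# e.g. remix enabled, rate disabled, gain normalization on:
cmd = build_sox_cmd("in.flac", "out.flac",
                    [["remix", "1,2"], [], ["gain", "-n", "-1"]])
```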
extract_sample(input_path, start, duration, output_path=None, output_format='flac', codec=None, compression_level=8)

Extract a sample segment from the audio file.

Parameters

input_path : Path
    Path to the input audio file.
start : float
    Start time in seconds.
duration : float
    Duration in seconds.
output_path : Path, optional
    Output file path. If None, auto-generated from input.
output_format : str, default="flac"
    Output audio format/extension.
codec : str, optional
    Audio codec to use (default: "flac" if output_format is "flac", else None).
compression_level : int, default=8
    Compression level for supported codecs.

Returns

Path
    Path to the extracted audio sample.

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
def extract_sample(
    self,
    input_path: Path,
    start: float,
    duration: float,
    output_path: Optional[Path] = None,
    output_format: str = "flac",
    codec: Optional[str] = None,
    compression_level: int = 8,
) -> Path:
    """
    Extract a sample segment from the audio file.

    Parameters
    ----------
    input_path : Path
        Path to the input audio file.
    start : float
        Start time in seconds.
    duration : float
        Duration in seconds.
    output_path : Path, optional
        Output file path. If None, auto-generated from input.
    output_format : str, default="flac"
        Output audio format/extension.
    codec : str, optional
        Audio codec to use (default: "flac" if output_format is "flac", else None).
    compression_level : int, default=8
        Compression level for supported codecs.

    Returns
    -------
    Path
        Path to the extracted audio sample.
    """
    input_path = Path(input_path)
    output_path = self._sample_output_path(input_path, output_path, start, duration, output_format)

    if codec is None:
        codec = "flac" if output_format == "flac" else None

    cmd = [
        "ffmpeg", "-y",
        "-ss", str(start),
        "-t", str(duration),
        "-i", str(input_path),
    ]
    if codec:
        cmd += ["-c:a", codec]
    if codec == "flac":
        cmd += ["-compression_level", str(compression_level)]
    cmd.append(str(output_path))

    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        logger.error(f"FFmpeg sample extraction failed: {result.stderr}")
        raise RuntimeError(f"Sample extraction failed: {result.stderr}")
    return output_path
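When `output_path` is None, the sample name is derived from the input stem plus the start and duration, with fractional seconds truncated. A standalone sketch of that naming logic, mirroring `_sample_output_path` (the input file name below is hypothetical):

```python
from pathlib import Path

def sample_output_path(input_path: Path, start: float, duration: float,
                       output_format: str = "flac") -> Path:
    # Fractional seconds are truncated by int(), so 12.7s becomes "12s".
    return (input_path.parent /
            f"{input_path.stem}_sample_{int(start)}s_{int(duration)}s.{output_format}")
```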
get_audio_info(file_path)

Get detailed audio information using ffprobe.

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
def get_audio_info(self, file_path: Path):
    """Get detailed audio information using ffprobe."""
    cmd = [
        "ffprobe", "-v", "quiet", "-print_format", "json",
        "-show_streams", "-select_streams", "a:0", str(file_path)
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode == 0:
        return self._display_stream_info(result, file_path)
    logger.error(f"FFprobe error: {result.stderr}")
    raise RuntimeError("Failed to retrieve audio info.")
play_audio(file_path)

Play audio in notebook for quality assessment.

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
def play_audio(self, file_path: Path):
    """Play audio in notebook for quality assessment."""
    display(Audio(str(file_path)))
CompressionSettings

Bases: BaseSettings

Compression settings for audio enhancement routines.

Attributes:

minimal : list[str]
    List of compand arguments for minimal compression.
light : list[str]
    List of compand arguments for light compression.
moderate : list[str]
    List of compand arguments for moderate compression.
aggressive : list[str]
    List of compand arguments for aggressive compression.
whisper_optimized : list[str]
    List of compand arguments for Whisper-optimized compression.
whisper_aggressive : list[str]
    List of compand arguments for aggressive Whisper compression.
primary_speech_only : list[str]
    List of compand arguments for primary speech only.

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
class CompressionSettings(BaseSettings):
    """Compression settings for audio enhancement routines.

    Attributes:
        minimal: List of compand arguments for minimal compression.
        light: List of compand arguments for light compression.
        moderate: List of compand arguments for moderate compression.
        aggressive: List of compand arguments for aggressive compression.
        whisper_optimized: List of compand arguments for Whisper-optimized compression.
        whisper_aggressive: List of compand arguments for aggressive Whisper compression.
        primary_speech_only: List of compand arguments for primary speech only.
    """
    minimal: list[str] = ["0.1,0.3", "3:-50,-40,-30,-20", "-3", "-80", "0.2"]
    light: list[str] = ["0.05,0.2", "6:-60,-50,-40,-30,-20,-10", "-3", "-85", "0.1"]
    moderate: list[str] = ["0.03,0.15", "6:-65,-50,-40,-30,-20,-10", "-4", "-85", "0.1"]
    aggressive: list[str] = ["0.02,0.1", "8:-70,-55,-45,-35,-25,-15", "-5", "-90", "0.05"]
    whisper_optimized: list[str] = ["0.005,0.06", "12:-75,-65,-55,-45,-35,-25,-15,-8", "-8", "-95", "0.03"]
    whisper_aggressive: list[str] = ["0.005,0.06", "12:-75,-45,-55,-30,-35,-18,-15,-8", "-8", "-95", "0.03"]
    primary_speech_only: list[str] = ["0.005,0.06", "12:-60,-45,-55,-30,-35,-18,-15,-8", "-8", "-60", "0.03"]
aggressive = ['0.02,0.1', '8:-70,-55,-45,-35,-25,-15', '-5', '-90', '0.05'] class-attribute instance-attribute
light = ['0.05,0.2', '6:-60,-50,-40,-30,-20,-10', '-3', '-85', '0.1'] class-attribute instance-attribute
minimal = ['0.1,0.3', '3:-50,-40,-30,-20', '-3', '-80', '0.2'] class-attribute instance-attribute
moderate = ['0.03,0.15', '6:-65,-50,-40,-30,-20,-10', '-4', '-85', '0.1'] class-attribute instance-attribute
primary_speech_only = ['0.005,0.06', '12:-60,-45,-55,-30,-35,-18,-15,-8', '-8', '-60', '0.03'] class-attribute instance-attribute
whisper_aggressive = ['0.005,0.06', '12:-75,-45,-55,-30,-35,-18,-15,-8', '-8', '-95', '0.03'] class-attribute instance-attribute
whisper_optimized = ['0.005,0.06', '12:-75,-65,-55,-45,-35,-25,-15,-8', '-8', '-95', '0.03'] class-attribute instance-attribute
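`_set_compand()` resolves one of these presets by name via `getattr`, falling back to `whisper_optimized` for unrecognized level names. A plain-class sketch of that lookup (the real settings object is a Pydantic BaseSettings; `CompressionPresets` and `compand_args` are invented names, with two preset values copied from above):

```python
class CompressionPresets:
    # Two of the presets above, copied for illustration.
    minimal = ["0.1,0.3", "3:-50,-40,-30,-20", "-3", "-80", "0.2"]
    whisper_optimized = ["0.005,0.06", "12:-75,-65,-55,-45,-35,-25,-15,-8",
                         "-8", "-95", "0.03"]

def compand_args(presets, level: str) -> list[str]:
    # Unknown level names silently fall back to the whisper_optimized preset.
    args = getattr(presets, level, presets.whisper_optimized)
    return ["compand", *args, ":"]
```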
EQSettings

Bases: BaseSettings

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
class EQSettings(BaseSettings):
    highpass_freq: int = 175
    lowpass_freq: int = 15000
    eq_bands: list[tuple[int, float, int]] = [
        (100, 0.9, -20),
        (1500, 1, 4),
        (4000, 0.6, 15),
        (10000, 1, -10)
    ]
    contrast: int = 75
    bass: tuple[int, int] = (-5, 200)
    treble: tuple[int, int] = (3, 3000)
bass = (-5, 200) class-attribute instance-attribute
contrast = 75 class-attribute instance-attribute
eq_bands = [(100, 0.9, -20), (1500, 1, 4), (4000, 0.6, 15), (10000, 1, -10)] class-attribute instance-attribute
highpass_freq = 175 class-attribute instance-attribute
lowpass_freq = 15000 class-attribute instance-attribute
treble = (3, 3000) class-attribute instance-attribute
EnhancementConfig

Bases: BaseModel

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
class EnhancementConfig(BaseModel):
    codec: str = 'flac'
    sample_rate: int = 48000
    channels: int = 2
    compression_level: str = 'aggressive'
    force_mono: bool = False
    target_rate: Optional[int] = None
    eq: EQSettings = EQSettings()
    gate: GateSettings = GateSettings()
    norm: NormalizationSettings = NormalizationSettings()
    remix: RemixSettings = RemixSettings()
    rate: RateSettings = RateSettings()
    include_gate: bool = True
    include_eq: bool = True
channels = 2 class-attribute instance-attribute
codec = 'flac' class-attribute instance-attribute
compression_level = 'aggressive' class-attribute instance-attribute
eq = EQSettings() class-attribute instance-attribute
force_mono = False class-attribute instance-attribute
gate = GateSettings() class-attribute instance-attribute
include_eq = True class-attribute instance-attribute
include_gate = True class-attribute instance-attribute
norm = NormalizationSettings() class-attribute instance-attribute
rate = RateSettings() class-attribute instance-attribute
remix = RemixSettings() class-attribute instance-attribute
sample_rate = 48000 class-attribute instance-attribute
target_rate = None class-attribute instance-attribute
GateSettings

Bases: BaseSettings

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
class GateSettings(BaseSettings):
    gate_params: list[str] = ["0.1", "0.05", "-inf", "0.1", "-90", "0.1"]
gate_params = ['0.1', '0.05', '-inf', '0.1', '-90', '0.1'] class-attribute instance-attribute
NormalizationSettings

Bases: BaseSettings

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
class NormalizationSettings(BaseSettings):
    norm_level: int = -1
norm_level = -1 class-attribute instance-attribute
RateSettings

Bases: BaseSettings

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
class RateSettings(BaseSettings):
    rate_args: list[str] = ["-v"]
rate_args = ['-v'] class-attribute instance-attribute
RemixSettings

Bases: BaseSettings

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
class RemixSettings(BaseSettings):
    remix_channels: str = "1,2"
remix_channels = '1,2' class-attribute instance-attribute
compress_wav_to_mp4_vbr(input_wav, output_path=None, quality=8)

Compress WAV to M4A (AAC VBR) using ffmpeg.

Parameters:

input_wav : str or Path
    Path to the input .wav file
output_path : str or Path, optional
    Output .mp4 file path. If None, auto-generated from input
quality : int, default=8
    VBR quality level: 1 = good (~96kbps), 2 = very good (~128kbps), 3+ = higher bitrate

Returns:

Path
    Path to the compressed .m4a file

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
def compress_wav_to_mp4_vbr(
    input_wav: str | Path, output_path: Optional[str | Path] = None, quality: int = 8
    ) -> Path:
    """
    Compress WAV to M4A (AAC VBR) using ffmpeg.

    Parameters:
    -----------
    input_wav : str or Path
        Path to the input .wav file
    output_path : str or Path, optional
        Output .mp4 file path. If None, auto-generated from input
    quality : int, default=8
        VBR quality level: 1 = good (~96kbps), 2 = very good (~128kbps), 3+ = higher bitrate

    Returns:
    --------
    Path
        Path to the compressed .m4a file
    """
    input_wav = Path(input_wav)
    if output_path is None:
        output_path = input_wav.with_suffix(".mp4")
    else:
        output_path = Path(output_path)

    cmd = [
        "ffmpeg", "-y", "-i", str(input_wav),
        "-c:a", "aac",
        "-q:a", str(quality),
        str(output_path)
    ]

    result = subprocess.run(cmd, capture_output=True, text=True)

    if result.returncode != 0:
        logger.error("Error compressing audio:")
        logger.error(result.stderr)
        raise RuntimeError("FFmpeg compression failed.")

    print(f"Compressed audio saved to: {output_path}")
    return output_path
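The default output naming above relies on `Path.with_suffix`, which swaps only the final extension, so the compressed file lands next to its input. A minimal sketch of that behavior (the `default_output` helper and the file name are hypothetical):

```python
from pathlib import Path

def default_output(input_wav: str) -> Path:
    # Mirrors the auto-naming in compress_wav_to_mp4_vbr: talk.wav -> talk.mp4.
    return Path(input_wav).with_suffix(".mp4")
```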
get_sox_info(file_path)

Get audio info using SoX

Source code in src/tnh_scholar/audio_processing/utils/audio_enhance.py
def get_sox_info(file_path):
    """Get audio info using SoX."""
    result = subprocess.run(["sox", "--info", str(file_path)],
                            capture_output=True, text=True)
    if result.returncode == 0:
        logger.info(result.stdout)
    else:
        logger.error(f"Error: {result.stderr}")

whisper_security

logger = get_child_logger(__name__) module-attribute
load_whisper_model(model_name)

Safely load a Whisper model with security best practices.

Parameters:

model_name : str
    Name of the Whisper model to load (e.g., "tiny", "base", "small")

Returns:

Any
    Loaded Whisper model

Raises:

RuntimeError
    If model loading fails

Source code in src/tnh_scholar/audio_processing/whisper_security.py
def load_whisper_model(model_name: str) -> Any:
    """
    Safely load a Whisper model with security best practices.

    Args:
        model_name: Name of the Whisper model to load (e.g., "tiny", "base", "small")

    Returns:
        Loaded Whisper model

    Raises:
        RuntimeError: If model loading fails
    """
    import whisper

    try:
        with safe_torch_load():
            model = whisper.load_model(model_name)
        return model
    except Exception as e:
        logger.error("Failed to load Whisper model %r: %s", model_name, e)
        raise RuntimeError(f"Failed to load Whisper model: {e}") from e
safe_torch_load(weights_only=True)

Context manager that temporarily modifies torch.load to use weights_only=True by default.

This addresses the FutureWarning in PyTorch regarding pickle security: https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models

Parameters:

weights_only : bool, default=True
    If True, limits unpickling to tensor data only.

Yields:

None

Example:

>>> with safe_torch_load():
...     model = whisper.load_model("tiny")

Source code in src/tnh_scholar/audio_processing/whisper_security.py
@contextlib.contextmanager
def safe_torch_load(weights_only: bool = True) -> Generator[None, None, None]:
    """
    Context manager that temporarily modifies torch.load 
    to use weights_only=True by default.

    This addresses the FutureWarning in PyTorch regarding pickle security:
    https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models

    Args:
        weights_only: If True, limits unpickling to tensor data only.

    Yields:
        None

    Example:
        >>> with safe_torch_load():
        ...     model = whisper.load_model("tiny")
    """
    original_torch_load = torch.load
    try:
        torch.load = partial(original_torch_load, weights_only=weights_only)
        logger.debug("Modified torch.load to use weights_only=%s", weights_only)
        yield
    finally:
        torch.load = original_torch_load
        logger.debug("Restored original torch.load")
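The pattern `safe_torch_load` uses, temporarily replacing a module attribute with a `functools.partial` that pins a keyword argument and restoring the original on exit, generalizes beyond `torch.load`. A standalone sketch of that pattern (the `pinned_kwarg` name is invented here):

```python
import contextlib
from functools import partial

@contextlib.contextmanager
def pinned_kwarg(obj, attr: str, **pinned):
    # Swap obj.attr for a partial with `pinned` keyword defaults,
    # restoring the original even if the body raises.
    original = getattr(obj, attr)
    try:
        setattr(obj, attr, partial(original, **pinned))
        yield
    finally:
        setattr(obj, attr, original)
```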

cli_tools

TNH Scholar CLI Tools

Command-line interface tools for the TNH Scholar project:

audio-transcribe:
    Audio processing pipeline that handles downloading, segmentation,
    and transcription of Buddhist teachings.

tnh-fab:
    Text processing tool providing functionality for
    punctuation, sectioning, translation, and pattern-based processing.

See individual tool documentation for usage details and examples.

audio_transcribe

audio_transcribe

CLI tool for downloading audio (from YouTube or local files) and transcribing it to text.

Usage

audio-transcribe [OPTIONS]

e.g., audio-transcribe --yt_url https://www.youtube.com/watch?v=EXAMPLE --output_dir ./processed --service whisper --model whisper-1

DEFAULT_CHUNK_DURATION = 120 module-attribute
DEFAULT_MIN_CHUNK = 10 module-attribute
DEFAULT_MODEL = 'whisper-1' module-attribute
DEFAULT_OUTPUT_PATH = './audio_transcriptions/transcript.txt' module-attribute
DEFAULT_RESPONSE_FORMAT = 'text' module-attribute
DEFAULT_SERVICE = 'whisper' module-attribute
DEFAULT_TEMP_DIR = tempfile.gettempdir() module-attribute
VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.wmv'} module-attribute
logger = get_child_logger(__name__) module-attribute
AudioTranscribeApp

Main application class for audio transcription CLI.

Organizes configuration, source resolution, and pipeline execution.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| yt_url | | YouTube URL to download audio from. | required |
| yt_url_csv | | CSV file containing YouTube URLs. | required |
| file_ | | Path to local audio file. | required |
| output_dir | | Directory for output files. | required |
| service | | Transcription service provider. | required |
| model | | Transcription model name. | required |
| language | | Language code for transcription. | required |
| response_format | | Format of transcription response. | required |
| chunk_duration | | Target chunk duration (seconds). | required |
| min_chunk | | Minimum chunk duration (seconds). | required |
| start_time | | Start time offset (HH:MM:SS). | required |
| end_time | | End time offset (HH:MM:SS). | required |
| prompt | | Prompt or keywords for transcription. | required |
Source code in src/tnh_scholar/cli_tools/audio_transcribe/audio_transcribe.py
class AudioTranscribeApp:
    """
    Main application class for audio transcription CLI.

    Organizes configuration, source resolution, and pipeline execution.

    Args:
        yt_url: YouTube URL to download audio from.
        yt_url_csv: CSV file containing YouTube URLs.
        file_: Path to local audio file.
        output_dir: Directory for output files.
        service: Transcription service provider.
        model: Transcription model name.
        language: Language code for transcription.
        response_format: Format of transcription response.
        chunk_duration: Target chunk duration (seconds).
        min_chunk: Minimum chunk duration (seconds).
        start_time: Start time offset (HH:MM:SS).
        end_time: End time offset (HH:MM:SS).
        prompt: Prompt or keywords for transcription.
    """
    def __init__(self, config: AudioTranscribeConfig) -> None:
        """
        Args:
            config: Validated AudioTranscribeConfig instance.
        """
        self.config = config
        self.yt_url = config.yt_url
        self.yt_url_csv = config.yt_url_csv
        self.file_ = config.file_
        self.output_path = Path(config.output)
        self.keep_artifacts = config.keep_artifacts
        # Use output directory for all artifacts if keep_artifacts is True, else use system temp
        if self.keep_artifacts:
            self.temp_dir = self.output_path.parent
        else:
            self.temp_dir = Path(tempfile.mkdtemp(dir=DEFAULT_TEMP_DIR))
        self.service = config.service
        self.model = config.model
        self.language = config.language
        self.response_format = config.response_format
        self.chunk_duration = TimeMs.from_seconds(config.chunk_duration)
        self.min_chunk = TimeMs.from_seconds(config.min_chunk)
        self.start_time = config.start_time
        self.end_time = config.end_time
        self.prompt = config.prompt
        ensure_directory_exists(self.output_path.parent)
        ensure_directory_exists(self.temp_dir)
        self.audio_file: Path = self._resolve_audio_source()
        self.transcription_options: dict = self._build_transcription_options()
        self.diarization_config = self._build_diarization_config()

    def run(self) -> None:
        """
        Run the transcription pipeline and print results, or just download audio if no_transcribe is set.
        """
        if self.config.no_transcribe:
            self._echo_settings()
            click.echo("\n[Download Only Mode]")
            click.echo(f"Downloaded audio file: {self.audio_file}")
            return
        pipeline = TranscriptionPipeline(
            audio_file=self.audio_file,
            output_dir=self.temp_dir,
            diarization_config=self.diarization_config,
            transcriber=self.service,
            transcription_options=self.transcription_options,
        )
        self._echo_settings()
        transcripts: list[str] = pipeline.run()
        self._write_transcript(transcripts)
        self._print_transcripts(transcripts)
        self._cleanup_temp_dir()

    def _cleanup_temp_dir(self) -> None:
        """
        Remove temp directory if not keeping artifacts.
        """
        if not self.keep_artifacts and self.temp_dir and self.temp_dir.exists():
            import shutil
            try:
                shutil.rmtree(self.temp_dir)
            except Exception as e:
                logger.warning(f"Failed to clean up temp directory: {self.temp_dir} ({e})")

    def _write_transcript(self, transcripts: list[str]) -> None:
        """
        Write the full transcript to the output file.

        Args:
            transcripts: List of transcript strings.
        """
        with open(self.output_path, "w", encoding="utf-8") as f:
            for chunk in transcripts:
                f.write(chunk.strip() + "\n\n")

    def _echo_settings(self) -> None:
        """
        Display all runtime settings except transcription_options and diarization_config.
        """
        click.echo("\n[Settings]")
        click.echo(f"  YouTube URL:         {self.yt_url}")
        click.echo(f"  YouTube CSV:         {self.yt_url_csv}")
        click.echo(f"  File:                {self.file_}")
        click.echo(f"  Output Path:         {self.output_path}")
        click.echo(f"  Temp Directory:      {self.temp_dir}")
        click.echo(f"  Service:             {self.service}")
        click.echo(f"  Model:               {self.model}")
        click.echo(f"  Language:            {self.language}")
        click.echo(f"  Response Format:     {self.response_format}")
        click.echo(f"  Chunk Duration:      {self.chunk_duration.to_seconds()} sec")
        click.echo(f"  Min Chunk:           {self.min_chunk.to_seconds()} sec")
        click.echo(f"  Start Time:          {self.start_time}")
        click.echo(f"  End Time:            {self.end_time}")
        click.echo(f"  Audio File:          {self.audio_file}")
        click.echo(f"  Prompt:              '{self.prompt}'")


    def _resolve_audio_source(self) -> Path:
        """
        Resolve and return the audio file to transcribe.

        Returns:
            Path: Path to the resolved audio file.
        Raises:
            FileNotFoundError: If no audio input is found.
            RuntimeError: If youtube-dl version check fails.
        """
        click.echo("[Resolving/Downloading Audio Source ...]")
        if self.yt_url_csv:
            self._set_yt_url_from_csv()
        if self.yt_url:
            return self._get_audio_from_youtube()
        if self.file_:
            return self._get_audio_from_file()
        logger.error("No audio input found.")
        raise FileNotFoundError("No audio input found.")

    def _set_yt_url_from_csv(self) -> None:
        """
        Set the YouTube URL from the first entry in the CSV file.
        """
        assert self.yt_url_csv
        urls: list[str] = get_youtube_urls_from_csv(Path(self.yt_url_csv))
        self.yt_url = urls[0] if urls else None

    def _get_audio_from_youtube(self) -> Path:
        """
        Download and return the audio file from YouTube.
        """
        if not check_ytd_version():
            logger.error("youtube-dl version check failed.")
            raise RuntimeError("youtube-dl version check failed.")

        dl = DLPDownloader()

        assert self.yt_url
        url_metadata = dl.get_metadata(self.yt_url)
        default_name = dl.get_default_filename_stem(url_metadata)
        download_path: Path = self.temp_dir / default_name
        download_file: Path = download_path.with_suffix(".mp3")
        if not download_file.exists():
            return self._extract_yt_audio(dl, download_path)
        click.echo(f"Re-using existing downloaded audio file: {download_file}")
        return download_file

    def _extract_yt_audio(self, dl: DLPDownloader, download_path: Path) -> Path:
        """
        Download YouTube audio via the downloader and return the file path.
        """
        video_data = dl.get_audio(
            self.yt_url,
            start=self.start_time,
            output_path=download_path,
        )
        if not video_data or not video_data.filepath:
            raise FileNotFoundError("Failed to download or locate audio file.")
        return Path(video_data.filepath)

    def _get_audio_from_file(self) -> Path:
        """
        Return the audio file path, converting video if needed.
        """
        assert self.file_
        audio_file: Path = Path(self.file_)
        if audio_file.suffix.lower() in VIDEO_EXTENSIONS:
            logger.info(f"Detected video file: {audio_file}. Auto-converting to mp3 ...")
            return convert_video_to_audio(audio_file, self.temp_dir)
        return audio_file

    def _build_transcription_options(self) -> dict:
        """
        Build transcription options dictionary for the pipeline.

        Returns:
            dict: Transcription options for the pipeline.
        """
        options: dict = {
            "model": self.model,
            "language": self.language,
            "response_format": self.response_format,
            "prompt": self.prompt,
        }
        if self.service == "whisper" and self.response_format != "text":
            options["timestamp_granularities"] = ["word"]
        return options

    def _build_diarization_config(self) -> DiarizationConfig:
        """
        Build DiarizationConfig for chunking and language settings.

        Returns:
            DiarizationConfig: Configuration for diarization and chunking.
        """
        from tnh_scholar.audio_processing.diarization.config import (
            ChunkConfig,
            DiarizationConfig,
            LanguageConfig,
            SpeakerConfig,
        )
        return DiarizationConfig(
            chunk=ChunkConfig(
                target_duration=self.chunk_duration,
                min_duration=self.min_chunk,
            ),
            speaker=SpeakerConfig(single_speaker=True),
            language=LanguageConfig(default_language=self.language),
        )

    def _print_transcripts(self, transcripts: list[str]) -> None:
        """
        Print each transcript chunk to stdout.

        Args:
            transcripts: List of transcript strings.
        """
        for i, text in enumerate(transcripts, 1):
            print(f"\n--- Transcript chunk {i} ---\n{text}\n")
audio_file = self._resolve_audio_source() instance-attribute
chunk_duration = TimeMs.from_seconds(config.chunk_duration) instance-attribute
config = config instance-attribute
diarization_config = self._build_diarization_config() instance-attribute
end_time = config.end_time instance-attribute
file_ = config.file_ instance-attribute
keep_artifacts = config.keep_artifacts instance-attribute
language = config.language instance-attribute
min_chunk = TimeMs.from_seconds(config.min_chunk) instance-attribute
model = config.model instance-attribute
output_path = Path(config.output) instance-attribute
prompt = config.prompt instance-attribute
response_format = config.response_format instance-attribute
service = config.service instance-attribute
start_time = config.start_time instance-attribute
temp_dir = self.output_path.parent instance-attribute
transcription_options = self._build_transcription_options() instance-attribute
yt_url = config.yt_url instance-attribute
yt_url_csv = config.yt_url_csv instance-attribute
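The option-building step performed by `_build_transcription_options` can be sketched as a standalone function. `build_transcription_options` below is a hypothetical free-function version for illustration, not part of the package API; it mirrors the rule that word-level timestamps are only requested for Whisper with a non-text response format:

```python
def build_transcription_options(service: str, model: str, language: str,
                                response_format: str, prompt: str) -> dict:
    """Assemble the options dict passed to the transcription pipeline."""
    options = {
        "model": model,
        "language": language,
        "response_format": response_format,
        "prompt": prompt,
    }
    # Word-level timestamps are only meaningful when the response
    # format carries structure (e.g. JSON), not for plain text.
    if service == "whisper" and response_format != "text":
        options["timestamp_granularities"] = ["word"]
    return options
```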
__init__(config)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| config | AudioTranscribeConfig | Validated AudioTranscribeConfig instance. | required |
Source code in src/tnh_scholar/cli_tools/audio_transcribe/audio_transcribe.py
def __init__(self, config: AudioTranscribeConfig) -> None:
    """
    Args:
        config: Validated AudioTranscribeConfig instance.
    """
    self.config = config
    self.yt_url = config.yt_url
    self.yt_url_csv = config.yt_url_csv
    self.file_ = config.file_
    self.output_path = Path(config.output)
    self.keep_artifacts = config.keep_artifacts
    # Use output directory for all artifacts if keep_artifacts is True, else use system temp
    if self.keep_artifacts:
        self.temp_dir = self.output_path.parent
    else:
        self.temp_dir = Path(tempfile.mkdtemp(dir=DEFAULT_TEMP_DIR))
    self.service = config.service
    self.model = config.model
    self.language = config.language
    self.response_format = config.response_format
    self.chunk_duration = TimeMs.from_seconds(config.chunk_duration)
    self.min_chunk = TimeMs.from_seconds(config.min_chunk)
    self.start_time = config.start_time
    self.end_time = config.end_time
    self.prompt = config.prompt
    ensure_directory_exists(self.output_path.parent)
    ensure_directory_exists(self.temp_dir)
    self.audio_file: Path = self._resolve_audio_source()
    self.transcription_options: dict = self._build_transcription_options()
    self.diarization_config = self._build_diarization_config()
run()

Run the transcription pipeline and print results, or just download audio if no_transcribe is set.

Source code in src/tnh_scholar/cli_tools/audio_transcribe/audio_transcribe.py
def run(self) -> None:
    """
    Run the transcription pipeline and print results, or just download audio if no_transcribe is set.
    """
    if self.config.no_transcribe:
        self._echo_settings()
        click.echo("\n[Download Only Mode]")
        click.echo(f"Downloaded audio file: {self.audio_file}")
        return
    pipeline = TranscriptionPipeline(
        audio_file=self.audio_file,
        output_dir=self.temp_dir,
        diarization_config=self.diarization_config,
        transcriber=self.service,
        transcription_options=self.transcription_options,
    )
    self._echo_settings()
    transcripts: list[str] = pipeline.run()
    self._write_transcript(transcripts)
    self._print_transcripts(transcripts)
    self._cleanup_temp_dir()
audio_transcribe(**kwargs)

CLI entry point for audio transcription.

Source code in src/tnh_scholar/cli_tools/audio_transcribe/audio_transcribe.py
@click.command()
@click.option(
    "-y", "--yt_url", type=str,
    help="Single YouTube URL."
)
@click.option(
    "-v", "--yt_url_csv", type=click.Path(exists=True),
    help="CSV file with multiple YouTube URLs in first column."
)
@click.option(
    "-f", "--file", "file_", type=click.Path(exists=True),
    help="Path to a local audio file."
)
@click.option(
    "-o", "--output", type=click.Path(), default=DEFAULT_OUTPUT_PATH,
    help="Path to the output transcript file."
)
# Removed temp_dir option, now handled by keep_artifacts only
@click.option(
    "-s", "--service", type=click.Choice(["whisper", "assemblyai"]), default=DEFAULT_SERVICE,
    help="Transcription service to use."
)
@click.option(
    "-m", "--model", type=str, default=DEFAULT_MODEL,
    help="Model to use for transcription (for OpenAI only)."
)
@click.option(
    "-l", "--language", type=str, default="en",
    help="Language code (e.g., 'en', 'vi')."
)
@click.option(
    "-r", "--response_format", type=str, default=DEFAULT_RESPONSE_FORMAT,
    help="Response format for Whisper (default: text)."
)
@click.option(
    "--chunk_duration", type=int, default=DEFAULT_CHUNK_DURATION,
    help="Chunk duration in seconds (default: 7 minutes)."
)
@click.option(
    "--min_chunk", type=int, default=DEFAULT_MIN_CHUNK,
    help="Minimum chunk duration in seconds."
)
@click.option(
    "--start_time", type=str,
    help="Start time offset for the input media (HH:MM:SS)."
)
@click.option(
    "--end_time", type=str,
    help="End time offset for the input media (HH:MM:SS)."
)
@click.option(
    "--prompt", type=str, default="",
    help="Prompt or keywords to guide the transcription."
)
@click.option(
    "-n", "--no_transcribe", is_flag=True, default=False,
    help="Download YouTube audio to mp3 only, do not transcribe. Requires --yt_url or --yt_url_csv."
)
@click.option(
    "-k", "--keep_artifacts", is_flag=True, default=False,
    help="Keep all intermediate artifacts in the output directory instead of using a system temp directory."
)
def audio_transcribe(**kwargs):
    """
    CLI entry point for audio transcription.
    """
    try:
        config = AudioTranscribeConfig(**kwargs)
    except NoAudioSourceError as e:
        print(f"\n[INPUT ERROR] {e}", flush=True)
        raise SystemExit(1) from e
    except MultipleAudioSourceError as e:
        print(f"\n[INPUT ERROR] {e}", flush=True)
        raise SystemExit(1) from e
    except ValidationError as e:
        print("\n[CONFIG VALIDATION ERROR]\n", e, flush=True)
        raise SystemExit(1) from e
    app = AudioTranscribeApp(config)
    app.run()
main()
Source code in src/tnh_scholar/cli_tools/audio_transcribe/audio_transcribe.py
def main():
    audio_transcribe()
config
DEFAULT_OUTPUT_PATH = './audio_transcriptions/transcript.txt' module-attribute
DEFAULT_SERVICE = 'whisper' module-attribute
DEFAULT_TEMP_DIR = './audio_transcriptions/tmp' module-attribute
AudioTranscribeConfig

Bases: BaseSettings

Source code in src/tnh_scholar/cli_tools/audio_transcribe/config.py
class AudioTranscribeConfig(BaseSettings):
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore"
    )

    yt_url: Optional[str] = Field(default=None, description="YouTube URL")
    yt_url_csv: Optional[str] = Field(default=None, description="CSV file with YouTube URLs")
    file_: Optional[str] = Field(default=None, description="Path to local audio file")
    output: str = Field(default=DEFAULT_OUTPUT_PATH, description="Path to output transcript file")
    temp_dir: Optional[str] = Field(default=None, description="Directory for temporary processing files")
    service: str = Field(
        default=DEFAULT_SERVICE, pattern="^(whisper|assemblyai)$", description="Transcription service"
    )
    model: str = Field(description="Transcription model name")
    language: str = Field(default="en", description="Language code")
    response_format: str = Field(description="Response format")
    chunk_duration: int = Field(description="Target chunk duration in seconds")
    min_chunk: int = Field(ge=10, description="Minimum chunk duration in seconds")
    start_time: Optional[str] = Field(default=None, description="Start time offset")
    end_time: Optional[str] = Field(default=None, description="End time offset")
    prompt: str = Field(default="", description="Prompt or keywords")

    no_transcribe: bool = Field(default=False, 
                                description="If True, only download YouTube audio to mp3, no transcription.")
    keep_artifacts: bool = Field(default=False, 
                                 description="Keep all intermediate artifacts in the output directory "
                                 "instead of using a system temp directory.")

    @model_validator(mode="after")
    def validate_sources(self):
        sources = [self.yt_url, self.yt_url_csv, self.file_]
        num_sources = sum(bool(s) for s in sources)
        if self.no_transcribe:
            # Only allow yt_url or yt_url_csv for download-only mode
            if not (self.yt_url or self.yt_url_csv):
                # Raise ValueError so pydantic wraps it in a ValidationError;
                # ValidationError itself cannot be constructed from a message string.
                raise ValueError(
                    "--no_transcribe requires a YouTube URL or CSV (--yt_url or --yt_url_csv)."
                )
            if self.file_:
                raise ValueError(
                    "--no_transcribe does not support local file input. Use --yt_url or --yt_url_csv only."
                )
        else:
            if num_sources == 0:
                raise NoAudioSourceError(
                    "No audio source provided: yt_url, yt_url_csv, or file_ input"
                )
            if num_sources > 1:
                raise MultipleAudioSourceError(
                    "Only one audio source may be provided at a time: yt_url, yt_url_csv, or file_ input."
                )
        return self
chunk_duration = Field(description='Target chunk duration in seconds') class-attribute instance-attribute
end_time = Field(default=None, description='End time offset') class-attribute instance-attribute
file_ = Field(default=None, description='Path to local audio file') class-attribute instance-attribute
keep_artifacts = Field(default=False, description='Keep all intermediate artifacts in the output directory instead of using a system temp directory.') class-attribute instance-attribute
language = Field(default='en', description='Language code') class-attribute instance-attribute
min_chunk = Field(ge=10, description='Minimum chunk duration in seconds') class-attribute instance-attribute
model = Field(description='Transcription model name') class-attribute instance-attribute
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', extra='ignore') class-attribute instance-attribute
no_transcribe = Field(default=False, description='If True, only download YouTube audio to mp3, no transcription.') class-attribute instance-attribute
output = Field(default=DEFAULT_OUTPUT_PATH, description='Path to output transcript file') class-attribute instance-attribute
prompt = Field(default='', description='Prompt or keywords') class-attribute instance-attribute
response_format = Field(description='Response format') class-attribute instance-attribute
service = Field(default=DEFAULT_SERVICE, pattern='^(whisper|assemblyai)$', description='Transcription service') class-attribute instance-attribute
start_time = Field(default=None, description='Start time offset') class-attribute instance-attribute
temp_dir = Field(default=None, description='Directory for temporary processing files') class-attribute instance-attribute
yt_url = Field(default=None, description='YouTube URL') class-attribute instance-attribute
yt_url_csv = Field(default=None, description='CSV file with YouTube URLs') class-attribute instance-attribute
validate_sources()
Source code in src/tnh_scholar/cli_tools/audio_transcribe/config.py
@model_validator(mode="after")
def validate_sources(self):
    sources = [self.yt_url, self.yt_url_csv, self.file_]
    num_sources = sum(bool(s) for s in sources)
    if self.no_transcribe:
        # Only allow yt_url or yt_url_csv for download-only mode
        if not (self.yt_url or self.yt_url_csv):
            # Raise ValueError so pydantic wraps it in a ValidationError;
            # ValidationError itself cannot be constructed from a message string.
            raise ValueError(
                "--no_transcribe requires a YouTube URL or CSV (--yt_url or --yt_url_csv)."
            )
        if self.file_:
            raise ValueError(
                "--no_transcribe does not support local file input. Use --yt_url or --yt_url_csv only."
            )
    else:
        if num_sources == 0:
            raise NoAudioSourceError(
                "No audio source provided: yt_url, yt_url_csv, or file_ input"
            )
        if num_sources > 1:
            raise MultipleAudioSourceError(
                "Only one audio source may be provided at a time: yt_url, yt_url_csv, or file_ input."
            )
    return self
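The mutual-exclusion rule enforced by `validate_sources` can be isolated into a plain function. `pick_source` below is a hypothetical helper for illustration (not part of the package); it returns the name of the single source provided and raises `ValueError` otherwise:

```python
from typing import Optional

def pick_source(yt_url: Optional[str], yt_url_csv: Optional[str],
                file_: Optional[str]) -> str:
    """Return the name of the one provided audio source, or raise ValueError."""
    provided = {
        name: value
        for name, value in [("yt_url", yt_url),
                            ("yt_url_csv", yt_url_csv),
                            ("file_", file_)]
        if value
    }
    if not provided:
        raise ValueError("No audio source provided: yt_url, yt_url_csv, or file_ input")
    if len(provided) > 1:
        raise ValueError("Only one audio source may be provided at a time")
    # dicts preserve insertion order, so this is the single provided name
    return next(iter(provided))
```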
MultipleAudioSourceError

Bases: ValueError

Raised when more than one audio source is provided.

Source code in src/tnh_scholar/cli_tools/audio_transcribe/config.py
class MultipleAudioSourceError(ValueError):
    """Raised when audio source selection has multiple sources)."""
NoAudioSourceError

Bases: ValueError

Raised when no audio source is provided.

Source code in src/tnh_scholar/cli_tools/audio_transcribe/config.py
class NoAudioSourceError(ValueError):
    """Raised when no audio source is provided."""
convert_video
FFMPEG_VIDEO_CONV_DEFAULT_CONFIG = {'audio_codec': 'libmp3lame', 'audio_bitrate': '192k', 'audio_samplerate': '44100'} module-attribute
logger = get_child_logger(__name__) module-attribute
convert_video_to_audio(video_file, output_dir, conversion_params=None)

Convert a video file to an audio file using ffmpeg.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| video_file | Path | Path to the video file | required |
| output_dir | Path | Directory to save the converted audio file | required |
| conversion_params | Optional[Dict[str, str]] | Optional dictionary to override default conversion parameters | None |

Returns:

| Type | Description |
|------|-------------|
| Path | Path to the converted audio file |

Source code in src/tnh_scholar/cli_tools/audio_transcribe/convert_video.py
def convert_video_to_audio(
    video_file: Path, 
    output_dir: Path,
    conversion_params: Optional[Dict[str, str]] = None
) -> Path:
    """
    Convert a video file to an audio file using ffmpeg.

    Args:
        video_file: Path to the video file
        output_dir: Directory to save the converted audio file
        conversion_params: Optional dictionary to override default conversion parameters

    Returns:
        Path to the converted audio file
    """
    output_file = output_dir / f"{video_file.stem}.mp3"

    if output_file.exists():
        logger.info(f"Audio file already exists: {output_file}")
        return output_file

    # Merge default config with any supplied parameters
    params = {**FFMPEG_VIDEO_CONV_DEFAULT_CONFIG}
    if conversion_params:
        params |= conversion_params

    logger.info(f"Converting video to audio: {video_file} -> {output_file}")
    logger.debug(f"Using conversion parameters: {params}")

    try:
        cmd = [
            "ffmpeg", 
            "-i", str(video_file),
            "-vn",
            "-acodec", params["audio_codec"],
            "-ab", params["audio_bitrate"],
            "-ar", params["audio_samplerate"],
            "-y",  # Overwrite output file if it exists
            str(output_file)
        ]
        subprocess.run(cmd, check=True, capture_output=True)
        logger.info(f"Conversion successful: {output_file}")
        return output_file
    except subprocess.CalledProcessError as e:
        logger.error(f"Conversion failed: {e.stderr.decode() if e.stderr else str(e)}")
        raise RuntimeError(f"Failed to convert video: {video_file}") from e
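The command assembly inside `convert_video_to_audio` can be sketched in isolation. `build_ffmpeg_cmd` below is a hypothetical helper that merges overrides into the defaults and returns the argument list without invoking ffmpeg, which makes the merge logic easy to test:

```python
from pathlib import Path
from typing import Dict, List, Optional

# Defaults mirroring FFMPEG_VIDEO_CONV_DEFAULT_CONFIG
DEFAULT_CONFIG = {
    "audio_codec": "libmp3lame",
    "audio_bitrate": "192k",
    "audio_samplerate": "44100",
}

def build_ffmpeg_cmd(video_file: Path, output_file: Path,
                     overrides: Optional[Dict[str, str]] = None) -> List[str]:
    """Assemble the ffmpeg argument list; overrides win over defaults."""
    params = {**DEFAULT_CONFIG, **(overrides or {})}
    return [
        "ffmpeg",
        "-i", str(video_file),
        "-vn",                           # drop the video stream
        "-acodec", params["audio_codec"],
        "-ab", params["audio_bitrate"],
        "-ar", params["audio_samplerate"],
        "-y",                            # overwrite output if it exists
        str(output_file),
    ]
```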
environment
env
logger = get_child_logger(__name__) module-attribute
check_env()

Check the environment for necessary conditions: 1. Check the OpenAI key is available. 2. Check that ffmpeg is available on PATH.

Source code in src/tnh_scholar/cli_tools/audio_transcribe/environment/env.py
def check_env() -> bool:
    """
    Check the environment for necessary conditions:
    1. Check the OpenAI key is available.
    2. Check that ffmpeg is available on PATH.
    """
    logger.debug("checking environment.")

    if not check_openai_env():
        return False

    if shutil.which("ffmpeg") is None:
        logger.error("ffmpeg not found in PATH. ffmpeg required for audio processing.")
        return False

    return True
check_requirements(requirements_file)

Check that all requirements listed in requirements.txt can be imported. If any cannot be imported, print a warning.

This is a heuristic check. Some packages may not share the same name as their importable module. Adjust the name mappings below as needed.

Example

check_requirements(Path("./requirements.txt"))

Prints warnings if imports fail, otherwise silent.
Source code in src/tnh_scholar/cli_tools/audio_transcribe/environment/env.py
def check_requirements(requirements_file: Path) -> None:
    """
    Check that all requirements listed in requirements.txt can be imported.
    If any cannot be imported, print a warning.

    This is a heuristic check. Some packages may not share the same name as their importable module.
    Adjust the name mappings below as needed.

    Example:
        >>> check_requirements(Path("./requirements.txt"))
        # Prints warnings if imports fail, otherwise silent.
    """
    # Map requirement names to their importable module names if they differ
    name_map = {
        "python-dotenv": "dotenv",
        "openai_whisper": "whisper",
        "protobuf": "google.protobuf",
        # Add other mappings if needed
    }

    # Parse requirements.txt to get a list of package names
    packages = []
    with requirements_file.open("r") as req_file:
        for line in req_file:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            # Each line generally looks like 'package==version'
            pkg_name = line.split("==")[0].strip()
            packages.append(pkg_name)

    # Try importing each package
    for pkg in packages:
        mod_name = name_map.get(pkg, pkg)
        try:
            __import__(mod_name)
        except ImportError:
            print(
                f"WARNING: Could not import '{mod_name}' from '{pkg}'. Check that it is correctly installed."
            )
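The per-line mapping from a requirements.txt entry to its importable module name can be factored out. `importable_name` below is a hypothetical helper illustrating the same heuristic (split on `==`, then consult the name map); like the original, it does not handle other specifiers such as `>=`:

```python
from typing import Optional

# Packages whose PyPI name differs from their importable module name
NAME_MAP = {
    "python-dotenv": "dotenv",
    "openai_whisper": "whisper",
    "protobuf": "google.protobuf",
}

def importable_name(requirement_line: str) -> Optional[str]:
    """Return the module name to import for one requirements.txt line,
    or None for blank lines and comments."""
    line = requirement_line.strip()
    if not line or line.startswith("#"):
        return None
    pkg = line.split("==")[0].strip()
    return NAME_MAP.get(pkg, pkg)
```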
transcription_pipeline
TranscriptionPipeline
Source code in src/tnh_scholar/cli_tools/audio_transcribe/transcription_pipeline.py
class TranscriptionPipeline:
    def __init__(
        self,
        audio_file: Path,
        output_dir: Path,
        diarization_config: Optional[DiarizationConfig] = None,
        transcriber: str = "whisper",
        transcription_options: Optional[Dict[str, Any]] = None,
        diarization_kwargs: Optional[Dict[str, Any]] = None,
        save_diarization: bool = True,
        logger: Optional[logging.Logger] = None,
    ):
        """
        Initialize the TranscriptionPipeline.

        Args:
            audio_file (Path): Path to the audio file to process.
            output_dir (Path): Directory to store output files.
            diarization_config (Optional[DiarizationConfig]): Diarization configuration.
            transcriber (str): Transcription service provider.
            transcription_options (Optional[Dict[str, Any]]): Options for transcription.
            diarization_kwargs (Optional[Dict[str, Any]]): Additional diarization arguments.
            save_diarization (bool): Whether to save raw diarization JSON results.
            logger (Optional[logging.Logger]): Logger for pipeline events.
        """
        self.logger = logger or logging.getLogger(__name__)
        self._validate_audio_file(audio_file)
        self._validate_output_dir(output_dir)

        self.audio_file = audio_file
        self.output_dir = output_dir
        self.diarization_config = diarization_config or DiarizationConfig()
        self.transcriber = transcriber
        if transcriber == "whisper":
            self.transcription_options = patch_whisper_options(
                transcription_options,
                file_extension=audio_file.suffix
            )
        else:
            self.transcription_options = transcription_options
        self.diarization_kwargs = diarization_kwargs or {}
        self.save_diarization = save_diarization

        if self.save_diarization:
            self.diarization_dir = self.output_dir / f"{self.audio_file.stem}_diarization"
            self.diarization_results_path = self.diarization_dir / "raw_diarization_results.json"
        else:
            self.diarization_dir = None
            self.diarization_results_path = None

        ensure_directory_writable(self.output_dir)
        if self.save_diarization:
            assert self.diarization_dir
            ensure_directory_writable(self.diarization_dir)

        self.audio_file_extension = audio_file.suffix

    def _validate_audio_file(self, audio_file: Path | str) -> None:
        """
        Validate the audio file input.

        Args:
            audio_file (Union[str, Path]): Path to the audio file.

        Raises:
            TypeError: If not a str or Path instance.
            FileNotFoundError: If file does not exist.
        """
        if isinstance(audio_file, str):
            audio_file = Path(audio_file)
        elif not isinstance(audio_file, Path):
            raise TypeError("audio_file must be a str or pathlib.Path instance")
        if not audio_file.exists() or not audio_file.is_file():
            raise FileNotFoundError(f"Audio file does not exist: {audio_file}")

    def _validate_output_dir(self, output_dir: Path | str) -> None:
        """
        Validate the output directory

        Args:
            output_dir (Path | str): Path to the output directory.

        Raises:
            TypeError: If not a Path or str instance.
        """
        if isinstance(output_dir, str):
            output_dir = Path(output_dir)
        elif not isinstance(output_dir, Path):
            raise TypeError("output_dir must be a str or pathlib.Path instance")


    def run(self) -> Optional[List[Dict[str, Any]]]:
        """
        Execute the full transcription pipeline with robust error handling.

        Returns:
            List[Dict[str, Any]]: List of transcript dicts with chunk metadata, or None on failure

        Raises:
            RuntimeError: If any pipeline step fails.
        """
        try:
            self.logger.info("Starting diarization step.")
            segments = self._run_diarization()
            if not segments:
                self.logger.warning("No diarization segments found.")
                return []
            self.logger.info("Chunking segments.")
            chunk_list = self._chunk_segments(segments)
            if not chunk_list:
                self.logger.warning("No chunks produced from segments.")
                return []
            self.logger.info("Extracting audio chunks.")
            self._extract_audio_chunks(chunk_list)
            self.logger.info("Transcribing chunks.")
            return self._transcribe_chunks(chunk_list)
        except Exception as exc:
            self._handle_pipeline_error(exc)
            return None

    def _run_diarization(self) -> List[Any]:
        """
        Orchestrate diarization and return domain-level segments.
        Uses structural pattern matching on the discriminated union.
        """
        # local import to avoid cycles
        from tnh_scholar.audio_processing.diarization import diarize, diarize_to_file

        if self.save_diarization:
            response: DiarizationResponse = diarize_to_file(
                audio_file_path=self.audio_file,
                output_path=self.diarization_results_path,
                wait_until_complete=True, # for this module defaulting to unlimited processing time
                **(self.diarization_kwargs or {})
            )
        else:
            response: DiarizationResponse = diarize(
                self.audio_file,
                wait_until_complete=True,
                **(self.diarization_kwargs or {})
            )
        if response is None:
            raise RuntimeError("Diarizer returned None response")

        # Discriminated-union matching
        match response:
            case DiarizationSucceeded(output=out):
                segments = getattr(out, "segments", None)
                if segments is None:
                    raise RuntimeError("DiarizationSucceeded missing 'segments'")
                self.logger.info(f"Diarization succeeded: {len(segments)} segments.")
                return segments

            case DiarizationFailed(error=err):
                raise RuntimeError(f"Diarization failed: {getattr(err, 'message', err)}")

            case DiarizationPending() | DiarizationRunning():
                raise RuntimeError("Diarization incomplete (pending/running).")

            case _:
                self.logger.error("Unhandled diarization response variant: "
                                  f"{type(response).__name__} - {response!r}"
                                  )
                raise RuntimeError("Unhandled diarization response variant")

    def _chunk_segments(self, segments):
        """
        Chunk diarization segments with error handling.
        """
        try:
            chunker = TimeGapChunker(config=self.diarization_config)
            chunks = chunker.extract(segments)
            if not chunks:
                self.logger.warning("No chunks produced from segments.")
            return chunks
        except Exception as exc:
            self.logger.error(f"Chunking segments failed: {exc}")
            raise RuntimeError(f"Chunking segments failed: {exc}") from exc

    def _extract_audio_chunks(self, chunk_list):
        """
        Extract audio chunks with error handling.
        Remove failed chunks from the list and add error metadata for traceability.
        """
        audio_handler = AudioHandler()
        successful_chunks = []
        for chunk in chunk_list:
            try:
                audio_handler.build_audio_chunk(chunk, audio_file=self.audio_file)
                successful_chunks.append(chunk)
            except Exception as exc:
                self.logger.error(f"Audio chunk extraction failed for chunk {chunk}: {exc}")
                # Do not add chunk to successful_chunks, effectively removing it from further processing
        # Update chunk_list in place to only include successful chunks
        chunk_list[:] = successful_chunks

    def _transcribe_chunks(self, chunk_list):
        """
        Transcribe audio chunks with error handling.
        """
        ts_service = TranscriptionServiceFactory.create_service(provider=self.transcriber)
        transcripts = []
        for chunk in chunk_list:
            transcript_text = None
            error_detail = None
            try:
                audio = chunk.audio
                if not audio:
                    self.logger.warning(f"No audio data for chunk {chunk}. Skipping transcription.")
                    continue
                audio_obj = audio.data
                transcript = ts_service.transcribe(
                    audio_obj,
                    self.transcription_options,
                )
                transcript_text = transcript.text
                error_detail = None
            except Exception as exc:
                self.logger.error(f"Transcription failed for chunk {chunk}: {exc}")
                transcript_text = None
                error_detail = str(exc)
            transcripts.append({
                "chunk": chunk,
                "transcript": transcript_text,
                "error": error_detail
            })
        return transcripts

    def _handle_pipeline_error(self, exc: Exception) -> None:
        """
        Handle pipeline errors in a modular way.

        Args:
            exc (Exception): The exception to handle.

        Raises:
            RuntimeError: Always re-raises the error after logging.
        """
        self.logger.error(f"TranscriptionPipeline failed: {exc}")
        raise RuntimeError(f"TranscriptionPipeline failed: {exc}") from exc
audio_file = audio_file instance-attribute
audio_file_extension = audio_file.suffix instance-attribute
diarization_config = diarization_config or DiarizationConfig() instance-attribute
diarization_dir = self.output_dir / f'{self.audio_file.stem}_diarization' instance-attribute
diarization_kwargs = diarization_kwargs or {} instance-attribute
diarization_results_path = self.diarization_dir / 'raw_diarization_results.json' instance-attribute
logger = logger or logging.getLogger(__name__) instance-attribute
output_dir = output_dir instance-attribute
save_diarization = save_diarization instance-attribute
transcriber = transcriber instance-attribute
transcription_options = patch_whisper_options(transcription_options, file_extension=(audio_file.suffix)) instance-attribute
__init__(audio_file, output_dir, diarization_config=None, transcriber='whisper', transcription_options=None, diarization_kwargs=None, save_diarization=True, logger=None)

Initialize the TranscriptionPipeline.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `audio_file` | `Path` | Path to the audio file to process. | required |
| `output_dir` | `Path` | Directory to store output files. | required |
| `diarization_config` | `Optional[DiarizationConfig]` | Diarization configuration. | `None` |
| `transcriber` | `str` | Transcription service provider. | `'whisper'` |
| `transcription_options` | `Optional[Dict[str, Any]]` | Options for transcription. | `None` |
| `diarization_kwargs` | `Optional[Dict[str, Any]]` | Additional diarization arguments. | `None` |
| `save_diarization` | `bool` | Whether to save raw diarization JSON results. | `True` |
| `logger` | `Optional[Logger]` | Logger for pipeline events. | `None` |
Source code in src/tnh_scholar/cli_tools/audio_transcribe/transcription_pipeline.py
def __init__(
    self,
    audio_file: Path,
    output_dir: Path,
    diarization_config: Optional[DiarizationConfig] = None,
    transcriber: str = "whisper",
    transcription_options: Optional[Dict[str, Any]] = None,
    diarization_kwargs: Optional[Dict[str, Any]] = None,
    save_diarization: bool = True,
    logger: Optional[logging.Logger] = None,
):
    """
    Initialize the TranscriptionPipeline.

    Args:
        audio_file (Path): Path to the audio file to process.
        output_dir (Path): Directory to store output files.
        diarization_config (Optional[DiarizationConfig]): Diarization configuration.
        transcriber (str): Transcription service provider.
        transcription_options (Optional[Dict[str, Any]]): Options for transcription.
        diarization_kwargs (Optional[Dict[str, Any]]): Additional diarization arguments.
        save_diarization (bool): Whether to save raw diarization JSON results.
        logger (Optional[logging.Logger]): Logger for pipeline events.
    """
    self.logger = logger or logging.getLogger(__name__)
    self._validate_audio_file(audio_file)
    self._validate_output_dir(output_dir)

    self.audio_file = audio_file
    self.output_dir = output_dir
    self.diarization_config = diarization_config or DiarizationConfig()
    self.transcriber = transcriber
    if transcriber == "whisper":
        self.transcription_options = patch_whisper_options(
            transcription_options,
            file_extension=audio_file.suffix
        )
    else:
        self.transcription_options = transcription_options
    self.diarization_kwargs = diarization_kwargs or {}
    self.save_diarization = save_diarization

    if self.save_diarization:
        self.diarization_dir = self.output_dir / f"{self.audio_file.stem}_diarization"
        self.diarization_results_path = self.diarization_dir / "raw_diarization_results.json"
    else:
        self.diarization_dir = None
        self.diarization_results_path = None

    ensure_directory_writable(self.output_dir)
    if self.save_diarization:
        assert self.diarization_dir
        ensure_directory_writable(self.diarization_dir)

    self.audio_file_extension = audio_file.suffix
run()

Execute the full transcription pipeline with robust error handling.

Returns:

| Type | Description |
|------|-------------|
| `Optional[List[Dict[str, Any]]]` | List of transcript dicts with chunk metadata, or `None` on failure. |

Raises:

| Type | Description |
|------|-------------|
| `RuntimeError` | If any pipeline step fails. |

Source code in src/tnh_scholar/cli_tools/audio_transcribe/transcription_pipeline.py
def run(self) -> Optional[List[Dict[str, Any]]]:
    """
    Execute the full transcription pipeline with robust error handling.

    Returns:
        List[Dict[str, Any]]: List of transcript dicts with chunk metadata, or None on failure

    Raises:
        RuntimeError: If any pipeline step fails.
    """
    try:
        self.logger.info("Starting diarization step.")
        segments = self._run_diarization()
        if not segments:
            self.logger.warning("No diarization segments found.")
            return []
        self.logger.info("Chunking segments.")
        chunk_list = self._chunk_segments(segments)
        if not chunk_list:
            self.logger.warning("No chunks produced from segments.")
            return []
        self.logger.info("Extracting audio chunks.")
        self._extract_audio_chunks(chunk_list)
        self.logger.info("Transcribing chunks.")
        return self._transcribe_chunks(chunk_list)
    except Exception as exc:
        self._handle_pipeline_error(exc)
        return None
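Each element returned by `run()` pairs a chunk with its transcript text or an error string (the dict shape built in `_transcribe_chunks`). A small sketch of post-processing such a result list — the sample data below is fabricated for illustration:

```python
def summarize_transcripts(transcripts):
    """Split pipeline results into joined transcript text and failed chunks."""
    texts, failures = [], []
    for entry in transcripts:
        if entry["error"] is None and entry["transcript"] is not None:
            texts.append(entry["transcript"])
        else:
            failures.append(entry)  # keep failed entries for retry/reporting
    return "\n\n".join(texts), failures

# Example with the dict shape produced by _transcribe_chunks:
results = [
    {"chunk": "chunk-0", "transcript": "Breathing in, I calm my body.", "error": None},
    {"chunk": "chunk-1", "transcript": None, "error": "timeout"},
]
full_text, failed = summarize_transcripts(results)
```

Because failed chunks carry their error string rather than being dropped, callers can report or retry them after the pipeline finishes.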
validate
validate_inputs(is_download, yt_url, yt_url_list, audio_file, split, transcribe, chunk_dir, no_chunks, silence_boundaries, whisper_boundaries)

Validate the CLI inputs to ensure logical consistency given all the flags.

Conditions & Requirements:

1. At least one action (yt_download, split, transcribe) must be requested. Otherwise nothing is done, so raise an error.
2. If yt_download is True:
   • Must specify either yt_process_url OR yt_process_url_list (not both, not none).
3. If yt_download is False:
   • If split is requested, a local audio file is needed (since no download will occur).
   • If transcribe is requested without split and without yt_download:
     • If no_chunks = False, chunk_dir must be provided to read existing chunks.
     • If no_chunks = True, a local audio file is needed for direct transcription (with yt_download=False there is no previously downloaded file to fall back on).
4. no_chunks flag:
   • If no_chunks = True, direct transcription is performed on the entire audio without chunking.
   • Cannot be combined with split (mutually exclusive).
   • chunk_dir is meaningless when no_chunks = True; providing it raises an error to prevent user confusion.
5. Boundaries flags (silence_boundaries, whisper_boundaries):
   • These flags control how splitting is done and are only meaningful with split; specifying them without split raises an error.
   • If split = True, exactly one method must be chosen: if both flags are True, an error is raised; if both are False (whisper_boundaries defaults to True, so this means the user disabled it without enabling silence_boundaries), an error is raised to avoid ambiguity.

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If the input arguments are not logically consistent. |

Source code in src/tnh_scholar/cli_tools/audio_transcribe/validate.py
def validate_inputs(
    is_download: bool,
    yt_url: str | None,
    yt_url_list: Path | None,
    audio_file: Path | None,
    split: bool,
    transcribe: bool,
    chunk_dir: Path | None,
    no_chunks: bool,
    silence_boundaries: bool,
    whisper_boundaries: bool,
) -> None:
    """
    Validate the CLI inputs to ensure logical consistency given all the flags.

    Conditions & Requirements:
    1. At least one action (yt_download, split, transcribe) should be requested.
       Otherwise, nothing is done, so raise an error.

    2. If yt_download is True:
       - Must specify either yt_process_url OR yt_process_url_list (not both, not none).

    3. If yt_download is False:
       - If split is requested, we need a local audio file (since no download will occur).
       - If transcribe is requested without split and without yt_download:
         - If no_chunks = False, we must have chunk_dir to read existing chunks.
         - If no_chunks = True, we must have a local audio file (direct transcription) or previously downloaded file
           (but since yt_download=False, previously downloaded file scenario doesn't apply here,
           so effectively we need local audio in that scenario).

    4. no_chunks flag:
       - If no_chunks = True, we are doing direct transcription on entire audio without chunking.
         - Cannot use split if no_chunks = True. (Mutually exclusive)
         - chunk_dir is irrelevant if no_chunks = True; since we don't split into chunks,
           requiring a chunk_dir doesn't make sense. If provided, it's not useful, but let's allow it silently
           or raise an error for clarity. It's safer to raise an error to prevent user confusion.

    5. Boundaries flags (silence_boundaries, whisper_boundaries):
       - These flags control how splitting is done.
       - If split = False, these are irrelevant. Not necessarily an error, but could be a no-op.
         For robustness, raise an error if user specifies these without split, to avoid confusion.
       - If split = True and no_chunks = True, that’s contradictory already, so no need for boundary logic there.
       - If split = True, exactly one method should be chosen:
         If both silence_boundaries and whisper_boundaries are True simultaneously or both are False simultaneously,
         we need a clear default or raise an error. By the code snippet logic, whisper_boundaries is default True
         if not stated otherwise. To keep it robust:
           - If both are True, raise error.
           - If both are False, that means user explicitly turned them off or never turned on whisper.
             The code snippet sets whisper_boundaries True by default. If user sets it False somehow,
             we can then default to silence. Just ensure at run-time we have a deterministic method:
             If both are False, we can default to whisper or silence. Let's default to whisper if no flags given.
             However, given the code snippet, whisper_boundaries has a default of True.
             If the user sets whisper_boundaries to False and also does not set silence_boundaries,
             then no method is chosen. Let's then raise an error if both ended up False to avoid ambiguity.

    Raises:
        ValueError: If the input arguments are not logically consistent.
    """

    # 1. Check that we have at least one action
    if not is_download and not split and not transcribe:
        raise ValueError(
            "No actions requested. At least one of --yt_download, --split, --transcribe, or --full must be set."
        )

    # 2. Validate YouTube download logic
    if is_download:
        if yt_url and yt_url_list:
            raise ValueError(
                "Both --yt_process_url and --yt_process_url_list provided. Only one allowed."
            )
        if not yt_url and not yt_url_list:
            raise ValueError(
                "When --yt_download is specified, you must provide --yt_process_url or --yt_process_url_list."
            )

    # 3. Logic when no YouTube download:
    if not is_download:
        # If splitting but no download, need an audio file
        if split and audio_file is None:
            raise ValueError(
                "Splitting requested but no audio file provided and no YouTube download source available."
            )

        if transcribe and not split:
            if no_chunks:
                # Direct transcription, need an audio file
                if audio_file is None:
                    raise ValueError(
                        "Transcription requested with no_chunks=True but no audio file provided."
                    )
            elif chunk_dir is None:
                raise ValueError(
                    "Transcription requested without splitting or downloading and no_chunks=False. Must provide --chunk_dir with pre-split chunks."
                )

    # Check no_chunks scenario:
    # no_chunks and split are mutually exclusive
    # If transcribing but not splitting or downloading:
    # If no_chunks and chunk_dir provided, it doesn't make sense since we won't use chunks at all.
    # 4. no_chunks flag validation:
    # no_chunks=False, we need chunks from chunk_dir
    if no_chunks:
        if split:
            raise ValueError(
                "Cannot use --no_chunks and --split together. Choose one option."
            )
        if chunk_dir is not None:
            raise ValueError("Cannot specify --chunk_dir when --no_chunks is set.")

    # 5. Boundaries flags:
    # If splitting is not requested but boundaries flags are set, it's meaningless.
    # The code snippet defaults whisper_boundaries to True, so if user tries to turn it off and sets silence?
    # We'll require that boundaries only matter if split is True.
    if not split and (silence_boundaries or whisper_boundaries):
        raise ValueError(
            "Boundary detection flags given but splitting is not requested. Remove these flags or enable --split."
        )

    # If split is True, we must have a consistent boundary method:
    if split:
        # If both whisper and silence are somehow True:
        if silence_boundaries and whisper_boundaries:
            raise ValueError(
                "Cannot use both --silence_boundaries and --whisper_boundaries simultaneously."
            )

        # If both are False:
        # Given the original snippet, whisper_boundaries is True by default.
        # For the sake of robustness, let's say if user sets both off, we can't proceed:
        if not silence_boundaries and not whisper_boundaries:
            raise ValueError(
                "No boundary method selected for splitting. Enable either whisper or silence boundaries."
            )
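The mutual-exclusion rules above can be exercised directly. A self-contained sketch of just the `no_chunks`/`split` and boundary-method checks (a simplified re-statement of rules 4 and 5, not the full function):

```python
def check_flag_consistency(split, no_chunks, chunk_dir,
                           silence_boundaries, whisper_boundaries):
    """Simplified re-statement of rules 4 and 5 from validate_inputs."""
    # Rule 4: no_chunks excludes split and makes chunk_dir meaningless.
    if no_chunks and split:
        raise ValueError("Cannot use --no_chunks and --split together.")
    if no_chunks and chunk_dir is not None:
        raise ValueError("Cannot specify --chunk_dir when --no_chunks is set.")
    # Rule 5: boundary flags only make sense with split, and exactly one
    # boundary method must be in effect when splitting.
    if not split and (silence_boundaries or whisper_boundaries):
        raise ValueError("Boundary flags require --split.")
    if split and silence_boundaries and whisper_boundaries:
        raise ValueError("Choose only one boundary method.")
    if split and not (silence_boundaries or whisper_boundaries):
        raise ValueError("No boundary method selected for splitting.")
```

Valid combinations pass silently; contradictory ones raise `ValueError`, matching the behavior of the full validator.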
version_check
logger = get_child_logger(__name__) module-attribute
YTDVersionChecker

Simple version checker for yt-dlp with robust version comparison.

This is a prototype implementation that may need expansion in these areas:

  • Caching to prevent frequent PyPI calls
  • More comprehensive error handling for:
    • Missing/uninstalled packages
    • Network timeouts
    • JSON parsing errors
    • Invalid version strings
  • Environment detection (virtualenv, conda, system Python)
  • Configuration options for version pinning
  • Proxy support for network requests

Source code in src/tnh_scholar/cli_tools/audio_transcribe/version_check.py
class YTDVersionChecker:
    """
    Simple version checker for yt-dlp with robust version comparison.

    This is a prototype implementation that may need expansion in these areas:
    - Caching to prevent frequent PyPI calls
    - More comprehensive error handling for:
        - Missing/uninstalled packages
        - Network timeouts
        - JSON parsing errors
        - Invalid version strings
    - Environment detection (virtualenv, conda, system Python)
    - Configuration options for version pinning
    - Proxy support for network requests
    """

    PYPI_URL = "https://pypi.org/pypi/yt-dlp/json"
    NETWORK_TIMEOUT = 5  # seconds

    def _get_installed_version(self) -> Version:
        """
        Get installed yt-dlp version.

        Returns:
            Version object representing installed version

        Raises:
            ImportError: If yt-dlp is not installed
            InvalidVersion: If installed version string is invalid
        """
        try:
            if version_str := str(importlib.metadata.version("yt-dlp")):
                return Version(version_str)
            else:
                raise InvalidVersion("yt-dlp version string is empty")
        except importlib.metadata.PackageNotFoundError as e:
            raise ImportError("yt-dlp is not installed") from e
        except InvalidVersion:
            raise

    def _get_latest_version(self) -> Version:
        """
        Get latest version from PyPI.

        Returns:
            Version object representing latest available version

        Raises:
            requests.RequestException: For any network-related errors
            InvalidVersion: If PyPI version string is invalid
            KeyError: If PyPI response JSON is malformed
        """
        try:
            response = requests.get(self.PYPI_URL, timeout=self.NETWORK_TIMEOUT)
            response.raise_for_status()
            version_str = response.json()["info"]["version"]
            return Version(version_str)
        except requests.RequestException as e:
            raise requests.RequestException(
                "Failed to fetch version from PyPI. Check network connection."
            ) from e

    def check_version(self) -> Tuple[bool, Version, Version]:
        """
        Check if yt-dlp needs updating.

        Returns:
            Tuple of (needs_update, installed_version, latest_version)

        Raises:
            ImportError: If yt-dlp is not installed
            requests.RequestException: For network-related errors
            InvalidVersion: If version strings are invalid
        """
        installed_version = self._get_installed_version()
        latest_version = self._get_latest_version()

        needs_update = installed_version < latest_version
        return needs_update, installed_version, latest_version
NETWORK_TIMEOUT = 5 class-attribute instance-attribute
PYPI_URL = 'https://pypi.org/pypi/yt-dlp/json' class-attribute instance-attribute
check_version()

Check if yt-dlp needs updating.

Returns:

| Type | Description |
|------|-------------|
| `Tuple[bool, Version, Version]` | Tuple of (needs_update, installed_version, latest_version). |

Raises:

| Type | Description |
|------|-------------|
| `ImportError` | If yt-dlp is not installed. |
| `RequestException` | For network-related errors. |
| `InvalidVersion` | If version strings are invalid. |

Source code in src/tnh_scholar/cli_tools/audio_transcribe/version_check.py
def check_version(self) -> Tuple[bool, Version, Version]:
    """
    Check if yt-dlp needs updating.

    Returns:
        Tuple of (needs_update, installed_version, latest_version)

    Raises:
        ImportError: If yt-dlp is not installed
        requests.RequestException: For network-related errors
        InvalidVersion: If version strings are invalid
    """
    installed_version = self._get_installed_version()
    latest_version = self._get_latest_version()

    needs_update = installed_version < latest_version
    return needs_update, installed_version, latest_version
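`check_version` leans on `packaging.version.Version` because release ordering is numeric, not lexicographic: the string `"2023.10.1"` sorts *before* `"2023.9.1"` even though it is the newer release. A stdlib-only sketch of why that matters (a simplified stand-in for illustration, not the `packaging` implementation, which also handles pre-releases and local version segments):

```python
def version_tuple(version_str: str) -> tuple:
    """Parse a simple dotted version like '2023.10.1' into an int tuple."""
    return tuple(int(part) for part in version_str.split("."))

# Plain string comparison gets release ordering wrong:
assert "2023.9.1" > "2023.10.1"   # lexicographic: '9' > '1'
# Numeric tuples order correctly, as packaging.version.Version does:
assert version_tuple("2023.9.1") < version_tuple("2023.10.1")
```

This is why `installed_version < latest_version` is computed on `Version` objects rather than on the raw strings from `importlib.metadata` and PyPI.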
check_ytd_version()

Check if yt-dlp needs updating and log appropriate messages.

This function checks the installed version of yt-dlp against the latest version on PyPI and logs informational or error messages as appropriate. It handles network errors, missing packages, and version parsing issues gracefully.

The function does not raise exceptions but logs them using the application's logging system. It returns False when an update is available and True otherwise (including when the check itself fails).

Source code in src/tnh_scholar/cli_tools/audio_transcribe/version_check.py
def check_ytd_version() -> bool:
    """
    Check if yt-dlp needs updating and log appropriate messages.

    This function checks the installed version of yt-dlp against the latest version
    on PyPI and logs informational or error messages as appropriate. It handles
    network errors, missing packages, and version parsing issues gracefully.

    The function does not raise exceptions but logs them using the application's
    logging system.

    Returns:
        bool: False if an update is available, True otherwise.
    """
    checker = YTDVersionChecker()
    try:
        needs_update, current, latest = checker.check_version()
        if needs_update:
            logger.info(f"Update available: {current} -> {latest}")
            logger.info("Please run the appropriate upgrade in your environment.")
            logger.info("   For example: pip install --upgrade yt-dlp ")
            return False
        else:
            logger.info(f"yt-dlp is up to date (version {current})")

    except ImportError as e:
        logger.error(f"In yt-dlp version check: Package error: {e}")
    except requests.RequestException as e:
        logger.error(f"In yt-dlp version check: Network error: {e}")
    except InvalidVersion as e:
        logger.error(f"In yt-dlp version check: Version parsing error: {e}")
    except Exception as e:
        logger.error(f"In yt-dlp version check: Unexpected error: {e}")

    return True
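The version comparison in check_version relies on PEP 440 ordering rather than string comparison, presumably via packaging.version, which also supplies the InvalidVersion error handled above. A small illustration with hypothetical version strings:

```python
from packaging.version import Version, InvalidVersion

# Hypothetical version strings for illustration.
installed = Version("2024.3.10")
latest = Version("2024.8.6")
needs_update = installed < latest  # PEP 440-aware ordering, not lexicographic
print(needs_update)  # True

try:
    Version("not-a-version")
except InvalidVersion as e:
    print(f"Version parsing error: {e}")
```

Note that `Version("1.10") > Version("1.9")`, which a plain string comparison would get wrong.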

json_to_srt

__all__ = ['main', 'json_to_srt'] module-attribute
main()

Entry point for the jsonl-to-srt CLI tool.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
def main():
    """Entry point for the jsonl-to-srt CLI tool."""
    json_to_srt()
json_to_srt

Simple CLI tool for converting JSONL transcription files to SRT format.

This module provides a command line interface for transforming JSONL transcription files (from audio-transcribe) into SRT subtitle format. Handles chunked transcriptions with proper timestamp accumulation.
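The timestamp accumulation can be sketched as a standalone function (`jsonl_chunks_to_srt` and `fmt` are illustrative helpers, not part of the package; the `segments`, `start`, `end`, `text`, and `duration` field names mirror the JSONL layout the converter reads):

```python
import json

def fmt(seconds: float) -> str:
    """Seconds -> SRT timestamp HH:MM:SS,mmm."""
    total_ms = round(seconds * 1000)
    h, rem = divmod(total_ms, 3_600_000)
    m, rem = divmod(rem, 60_000)
    s, ms = divmod(rem, 1000)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

def jsonl_chunks_to_srt(lines):
    """Chunked JSONL -> SRT, shifting each chunk by the durations before it."""
    entries, index, offset = [], 1, 0.0
    for line in lines:
        data = json.loads(line)
        for seg in data.get("segments", []):
            text = seg.get("text", "").strip()
            if not text:
                continue
            start = seg.get("start", 0) + offset  # offset = prior chunks' time
            end = seg.get("end", 0) + offset
            entries.append(f"{index}\n{fmt(start)} --> {fmt(end)}\n{text}\n")
            index += 1
        offset += data.get("duration", 0.0)  # accumulate after each chunk
    return "\n".join(entries)

chunks = [
    json.dumps({"duration": 10.0,
                "segments": [{"start": 0.0, "end": 2.0, "text": "Hello"}]}),
    json.dumps({"duration": 5.0,
                "segments": [{"start": 1.0, "end": 3.0, "text": "World"}]}),
]
print(jsonl_chunks_to_srt(chunks))
```

The second entry starts at 00:00:11,000 because the first chunk contributed 10 s of accumulated time.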

logger = get_child_logger(__name__) module-attribute
JsonlToSrtConverter

Converts JSONL transcription files from audio-transcribe to SRT format.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
class JsonlToSrtConverter:
    """Converts JSONL transcription files from audio-transcribe to SRT format."""

    def __init__(self):
        """Initialize converter state."""
        self.entry_index = 1
        self.accumulated_time = 0.0

    def format_timestamp(self, seconds: float) -> str:
        """Convert seconds to SRT timestamp format (HH:MM:SS,mmm)."""
        # Work in whole milliseconds so durations past 24h format correctly
        # and millisecond rounding cannot overflow to 1000.
        total_ms = round(seconds * 1000)
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        secs, milliseconds = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"

    def parse_jsonl_line(self, line: str) -> Dict:
        """Parse a single JSONL line into a dictionary."""
        try:
            return json.loads(line.strip())
        except json.JSONDecodeError as e:
            logger.error(f"Error parsing JSONL line: {e}")
            return {}

    def build_srt_entry(self, index: int, start: float, end: float, text: str) -> str:
        """Format a single SRT entry."""
        start_str = self.format_timestamp(start)
        end_str = self.format_timestamp(end)
        return f"{index}\n{start_str} --> {end_str}\n{text}\n"

    def extract_segment_data(self, segment: Dict) -> Tuple[float, float, str]:
        """Extract timestamp and text data from a segment."""
        start = segment.get("start", 0) + self.accumulated_time
        end = segment.get("end", 0) + self.accumulated_time
        text = segment.get("text", "").strip()
        return start, end, text

    def process_segment(self, segment: Dict) -> Optional[str]:
        """Process a single segment into SRT format."""
        start, end, text = self.extract_segment_data(segment)

        if not text:
            return None

        entry = self.build_srt_entry(self.entry_index, start, end, text)
        self.entry_index += 1
        return entry

    def process_segments_list(self, segments_list: List[Dict]) -> List[str]:
        """Process a list of segments into SRT entries."""
        entries = []

        for segment in segments_list:
            if entry := self.process_segment(segment):
                entries.append(entry)

        return entries

    def get_segments_from_data(self, data: Dict) -> List[Dict]:
        """Extract segments from a data object."""
        return data.get("segments", [])

    def read_input_lines(self, input_file: TextIO) -> List[str]:
        """Read and filter input lines from file."""
        return [line.strip() for line in input_file if line.strip()]

    def process_jsonl_line(self, line: str) -> List[str]:
        """Process a single JSONL line into SRT entries."""
        data = self.parse_jsonl_line(line)
        if not data:
            return []

        # Extract duration for accumulation
        chunk_duration = data.get("duration", 0.0)

        segments = self.get_segments_from_data(data)
        entries = self.process_segments_list(segments)

        # Update accumulated time after processing this chunk
        self.accumulated_time += chunk_duration
        return entries

    def process_jsonl_content(self, lines: List[str]) -> str:
        """Process all JSONL content into SRT format."""
        all_entries = []

        for line in lines:
            entries = self.process_jsonl_line(line)
            all_entries.extend(entries)

        return "\n".join(all_entries)

    def handle_output(self, srt_content: str, output_file: Optional[Path]) -> None:
        """Write SRT content to file or stdout."""
        if output_file:
            write_str_to_file(output_file, srt_content, overwrite=True)
            logger.info(f"SRT content written to {output_file}")
        else:
            click.echo(srt_content)

    def convert(self, input_file: TextIO, output_file: Optional[Path] = None) -> str:
        """
        Convert a JSONL transcription file to SRT format.

        Args:
            input_file: JSONL transcription file to parse
            output_file: Optional output file path

        Returns:
            str: SRT formatted content
        """
        input_lines = self.read_input_lines(input_file)
        srt_content = self.process_jsonl_content(input_lines)
        self.handle_output(srt_content, output_file)
        return srt_content
accumulated_time = 0.0 instance-attribute
entry_index = 1 instance-attribute
__init__()

Initialize converter state.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
def __init__(self):
    """Initialize converter state."""
    self.entry_index = 1
    self.accumulated_time = 0.0
build_srt_entry(index, start, end, text)

Format a single SRT entry.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
def build_srt_entry(self, index: int, start: float, end: float, text: str) -> str:
    """Format a single SRT entry."""
    start_str = self.format_timestamp(start)
    end_str = self.format_timestamp(end)
    return f"{index}\n{start_str} --> {end_str}\n{text}\n"
convert(input_file, output_file=None)

Convert a JSONL transcription file to SRT format.

Parameters:

Name Type Description Default
input_file TextIO

JSONL transcription file to parse

required
output_file Optional[Path]

Optional output file path

None

Returns:

Name Type Description
str str

SRT formatted content

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
def convert(self, input_file: TextIO, output_file: Optional[Path] = None) -> str:
    """
    Convert a JSONL transcription file to SRT format.

    Args:
        input_file: JSONL transcription file to parse
        output_file: Optional output file path

    Returns:
        str: SRT formatted content
    """
    input_lines = self.read_input_lines(input_file)
    srt_content = self.process_jsonl_content(input_lines)
    self.handle_output(srt_content, output_file)
    return srt_content
extract_segment_data(segment)

Extract timestamp and text data from a segment.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
def extract_segment_data(self, segment: Dict) -> Tuple[float, float, str]:
    """Extract timestamp and text data from a segment."""
    start = segment.get("start", 0) + self.accumulated_time
    end = segment.get("end", 0) + self.accumulated_time
    text = segment.get("text", "").strip()
    return start, end, text
format_timestamp(seconds)

Convert seconds to SRT timestamp format (HH:MM:SS,mmm).

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
def format_timestamp(self, seconds: float) -> str:
    """Convert seconds to SRT timestamp format (HH:MM:SS,mmm)."""
    # Work in whole milliseconds so durations past 24h format correctly
    # and millisecond rounding cannot overflow to 1000.
    total_ms = round(seconds * 1000)
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, milliseconds = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"
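A standalone version of the timestamp math, with sample values (whole-millisecond arithmetic keeps hours correct past 24h and prevents the milliseconds field from rounding up to 1000):

```python
def format_timestamp(seconds: float) -> str:
    """Convert seconds to the SRT timestamp format HH:MM:SS,mmm."""
    total_ms = round(seconds * 1000)
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, ms = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"

print(format_timestamp(3661.5))    # 01:01:01,500
print(format_timestamp(90061.25))  # 25:01:01,250 (hours past 24 are kept)
```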
get_segments_from_data(data)

Extract segments from a data object.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
def get_segments_from_data(self, data: Dict) -> List[Dict]:
    """Extract segments from a data object."""
    return data.get("segments", [])
handle_output(srt_content, output_file)

Write SRT content to file or stdout.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
def handle_output(self, srt_content: str, output_file: Optional[Path]) -> None:
    """Write SRT content to file or stdout."""
    if output_file:
        write_str_to_file(output_file, srt_content, overwrite=True)
        logger.info(f"SRT content written to {output_file}")
    else:
        click.echo(srt_content)
parse_jsonl_line(line)

Parse a single JSONL line into a dictionary.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
def parse_jsonl_line(self, line: str) -> Dict:
    """Parse a single JSONL line into a dictionary."""
    try:
        return json.loads(line.strip())
    except json.JSONDecodeError as e:
        logger.error(f"Error parsing JSONL line: {e}")
        return {}
process_jsonl_content(lines)

Process all JSONL content into SRT format.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
def process_jsonl_content(self, lines: List[str]) -> str:
    """Process all JSONL content into SRT format."""
    all_entries = []

    for line in lines:
        entries = self.process_jsonl_line(line)
        all_entries.extend(entries)

    return "\n".join(all_entries)
process_jsonl_line(line)

Process a single JSONL line into SRT entries.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
def process_jsonl_line(self, line: str) -> List[str]:
    """Process a single JSONL line into SRT entries."""
    data = self.parse_jsonl_line(line)
    if not data:
        return []

    # Extract duration for accumulation
    chunk_duration = data.get("duration", 0.0)

    segments = self.get_segments_from_data(data)
    entries = self.process_segments_list(segments)

    # Update accumulated time after processing this chunk
    self.accumulated_time += chunk_duration
    return entries
process_segment(segment)

Process a single segment into SRT format.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
def process_segment(self, segment: Dict) -> Optional[str]:
    """Process a single segment into SRT format."""
    start, end, text = self.extract_segment_data(segment)

    if not text:
        return None

    entry = self.build_srt_entry(self.entry_index, start, end, text)
    self.entry_index += 1
    return entry
process_segments_list(segments_list)

Process a list of segments into SRT entries.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
def process_segments_list(self, segments_list: List[Dict]) -> List[str]:
    """Process a list of segments into SRT entries."""
    entries = []

    for segment in segments_list:
        if entry := self.process_segment(segment):
            entries.append(entry)

    return entries
read_input_lines(input_file)

Read and filter input lines from file.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
def read_input_lines(self, input_file: TextIO) -> List[str]:
    """Read and filter input lines from file."""
    return [line.strip() for line in input_file if line.strip()]
json_to_srt(input_file, output=None)

Convert JSONL transcription files to SRT subtitle format.

Reads from stdin if no INPUT_FILE is specified. Writes to stdout if no output file is specified.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
@click.command()
@click.argument("input_file", type=click.File("r"), default="-")
@click.option(
    "-o", 
    "--output", 
    type=click.Path(path_type=Path), 
    help="Output file (default: stdout)"
)
def json_to_srt(input_file: TextIO, output: Optional[Path] = None) -> None:
    """
    Convert JSONL transcription files to SRT subtitle format.

    Reads from stdin if no INPUT_FILE is specified.
    Writes to stdout if no output file is specified.
    """
    try:
        converter = JsonlToSrtConverter()
        converter.convert(input_file, output)
    except Exception as e:
        logger.error(f"Error processing file: {e}")
        sys.exit(1)
main()

Entry point for the jsonl-to-srt CLI tool.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt.py
def main():
    """Entry point for the jsonl-to-srt CLI tool."""
    json_to_srt()
json_to_srt1

Simple CLI tool for converting JSONL transcription files to SRT format.

This module provides a command line interface for transforming JSONL transcription files (from audio-transcribe) into SRT subtitle format.

logger = get_child_logger(__name__) module-attribute
convert_to_srt(input_file, output_file=None)

Convert a JSONL transcription file to SRT format.

Parameters:

Name Type Description Default
input_file TextIO

JSONL transcription file to parse

required
output_file Optional[Path]

Optional output file path

None

Returns:

Name Type Description
str str

SRT formatted content

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt1.py
def convert_to_srt(input_file: TextIO, output_file: Optional[Path] = None) -> str:
    """
    Convert a JSONL transcription file to SRT format.

    Args:
        input_file: JSONL transcription file to parse
        output_file: Optional output file path

    Returns:
        str: SRT formatted content
    """
    input_lines = read_input_lines(input_file)
    srt_content = process_jsonl_content(input_lines)
    handle_output(srt_content, output_file)
    return srt_content
extract_segment_data(segment)

Extract timestamp and text data from a segment.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt1.py
def extract_segment_data(segment: Dict) -> Tuple[float, float, str]:
    """Extract timestamp and text data from a segment."""
    start = segment.get("start", 0)
    end = segment.get("end", 0)
    text = segment.get("text", "").strip()
    return start, end, text
format_srt_entry(index, start, end, text)

Format a single SRT entry.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt1.py
def format_srt_entry(index: int, start: float, end: float, text: str) -> str:
    """Format a single SRT entry."""
    start_str = format_timestamp(start)
    end_str = format_timestamp(end)
    return f"{index}\n{start_str} --> {end_str}\n{text}\n"
format_timestamp(seconds)

Convert seconds to SRT timestamp format (HH:MM:SS,mmm).

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt1.py
def format_timestamp(seconds: float) -> str:
    """Convert seconds to SRT timestamp format (HH:MM:SS,mmm)."""
    # Work in whole milliseconds so durations past 24h format correctly
    # and millisecond rounding cannot overflow to 1000.
    total_ms = round(seconds * 1000)
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, milliseconds = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"
get_segments_from_data(data)

Extract segments from a data object.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt1.py
def get_segments_from_data(data: Dict) -> List[Dict]:
    """Extract segments from a data object."""
    return data.get("segments", [])
handle_output(srt_content, output_file)

Write SRT content to file or stdout.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt1.py
def handle_output(srt_content: str, output_file: Optional[Path]) -> None:
    """Write SRT content to file or stdout."""
    if output_file:
        write_str_to_file(output_file, srt_content)
        logger.info(f"SRT content written to {output_file}")
    else:
        click.echo(srt_content)
json_to_srt(input_file, output=None)

Convert JSONL transcription files to SRT subtitle format.

Reads from stdin if no INPUT_FILE is specified. Writes to stdout if no output file is specified.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt1.py
@click.command()
@click.argument("input_file", type=click.File("r"), default="-")
@click.option(
    "-o", 
    "--output", 
    type=click.Path(path_type=Path), 
    help="Output file (default: stdout)"
)
def json_to_srt(input_file: TextIO, output: Optional[Path] = None) -> None:
    """
    Convert JSONL transcription files to SRT subtitle format.

    Reads from stdin if no INPUT_FILE is specified.
    Writes to stdout if no output file is specified.
    """
    try:
        convert_to_srt(input_file, output)
    except Exception as e:
        logger.error(f"Error processing file: {e}")
        sys.exit(1)
main()

Entry point for the jsonl-to-srt CLI tool.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt1.py
def main():
    """Entry point for the jsonl-to-srt CLI tool."""
    json_to_srt()
parse_jsonl_line(line)

Parse a single JSONL line into a dictionary.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt1.py
def parse_jsonl_line(line: str) -> Dict:
    """Parse a single JSONL line into a dictionary."""
    try:
        return json.loads(line.strip())
    except json.JSONDecodeError as e:
        logger.error(f"Error parsing JSONL line: {e}")
        return {}
process_jsonl_content(lines)

Process all JSONL content into SRT format.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt1.py
def process_jsonl_content(lines: List[str]) -> str:
    """Process all JSONL content into SRT format."""
    all_entries = []
    entry_index = 1
    accumulated_time = 0.0  # Track total duration of processed chunks

    for line in lines:
        entries, entry_index, chunk_duration = process_jsonl_line(
            line, entry_index, accumulated_time)
        all_entries.extend(entries)

        # Add this chunk's duration to accumulated time
        accumulated_time += chunk_duration  

    return "\n".join(all_entries)
process_jsonl_line(line, entry_index, accumulated_time)

Process a single JSONL line into SRT entries.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt1.py
def process_jsonl_line(
    line: str, entry_index: int, accumulated_time: float
) -> Tuple[List[str], int, float]:
    """Process a single JSONL line into SRT entries.

    Returns the SRT entries, the next entry index, and this chunk's duration
    (so the caller can keep accumulating time across chunks).
    """
    data = parse_jsonl_line(line)
    if not data:
        return [], entry_index, 0.0

    chunk_duration = data.get("duration", 0.0)
    segments = get_segments_from_data(data)
    entries, entry_index = process_segments_list(segments, entry_index)
    return entries, entry_index, chunk_duration
process_segment(segment, entry_index)

Process a single segment into SRT format.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt1.py
def process_segment(segment: Dict, entry_index: int) -> Tuple[str, int]:
    """Process a single segment into SRT format."""
    start, end, text = extract_segment_data(segment)

    if not text:
        return "", entry_index

    entry = format_srt_entry(entry_index, start, end, text)
    return entry, entry_index + 1
process_segments_list(segments_list, entry_index)

Process a list of segments into SRT entries.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt1.py
def process_segments_list(segments_list: List[Dict], 
                          entry_index: int) -> Tuple[List[str], int]:
    """Process a list of segments into SRT entries."""
    entries = []

    for segment in segments_list:
        entry, entry_index = process_segment(segment, entry_index)
        if entry:
            entries.append(entry)

    return entries, entry_index
read_input_lines(input_file)

Read and filter input lines from file.

Source code in src/tnh_scholar/cli_tools/json_to_srt/json_to_srt1.py
def read_input_lines(input_file: TextIO) -> List[str]:
    """Read and filter input lines from file."""
    return [line.strip() for line in input_file if line.strip()]

nfmt

nfmt
main()

Entry point for the nfmt CLI tool.

Source code in src/tnh_scholar/cli_tools/nfmt/nfmt.py
def main():
    """Entry point for the nfmt CLI tool."""
    nfmt()
nfmt(input_file, output, spacing)

Normalize the number of newlines in a text file.

Source code in src/tnh_scholar/cli_tools/nfmt/nfmt.py
@click.command()
@click.argument("input_file", type=click.File("r"), default="-")
@click.option(
    "-o",
    "--output",
    type=click.File("w"),
    default="-",
    help="Output file (default: stdout)",
)
@click.option(
    "-s", "--spacing", default=2, help="Number of newlines between blocks (default: 2)"
)
def nfmt(input_file, output, spacing):
    """Normalize the number of newlines in a text file."""
    text = input_file.read()
    result = normalize_newlines(text, spacing)
    output.write(result)
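normalize_newlines itself is defined elsewhere in the package; a plausible minimal sketch of the behavior the -s/--spacing option describes (rewriting every blank-line run to a fixed number of newlines) could look like:

```python
import re

def normalize_newlines(text: str, spacing: int = 2) -> str:
    """Rewrite every run of 2+ newlines to exactly `spacing` newlines."""
    # Trailing whitespace is trimmed, then a single final newline restored.
    return re.sub(r"\n{2,}", "\n" * spacing, text.rstrip()) + "\n"

print(normalize_newlines("block one\n\n\n\nblock two\n"))
```

With the default spacing of 2, blocks end up separated by exactly one blank line; larger values expand the gaps as well as collapsing them.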

sent_split

sent_split

Simple CLI tool for sentence splitting.

This module provides a command line interface for splitting text into sentences. Uses NLTK for robust sentence tokenization. Reads from stdin and writes to stdout by default, with optional file input/output.
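The split-and-join flow can be sketched with a naive regex splitter standing in for NLTK's punkt tokenizer (which the real tool uses, and which handles abbreviations and edge cases the regex cannot):

```python
import re

def naive_sent_split(text: str, separator: str = "newline") -> str:
    """Stand-in for nltk.sent_tokenize: break on end punctuation + space."""
    sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
    sep = "\n" if separator == "newline" else " "
    return sep.join(sentences)

print(naive_sent_split("Breathing in, I calm my body. Breathing out, I smile."))
```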

SplitConfig

Bases: BaseModel

Source code in src/tnh_scholar/cli_tools/sent_split/sent_split.py
class SplitConfig(BaseModel):
    separator: Literal["space", "newline"] = "newline"
    nltk_tokenizer: str = "punkt"
nltk_tokenizer = 'punkt' class-attribute instance-attribute
separator = 'newline' class-attribute instance-attribute
SplitIOData

Bases: BaseModel

Source code in src/tnh_scholar/cli_tools/sent_split/sent_split.py
class SplitIOData(BaseModel):
    input_path: Optional[Path] = None
    output_path: Optional[Path] = None
    content: Optional[str] = None

    @classmethod
    def from_io(
        cls, input_file: Optional[Path], output: Optional[Path]
        ) -> "SplitIOData":
        return cls(input_path=input_file, output_path=output)

    def get_input_content(self) -> str:
        if self.content is not None:
            return self.content
        return read_str_from_file(self.input_path) if self.input_path \
            else sys.stdin.read()

    def write_output(self, result: SplitResult) -> None:
        text = result.text_object
        output_str = str(text)
        if self.output_path:
            write_str_to_file(self.output_path, output_str)
            click.echo(f"Output written to: {self.output_path.name}")
            click.echo(f"Split into {result.stats['sentence_count']} sentences.")
        else:
            click.echo(output_str)
content = None class-attribute instance-attribute
input_path = None class-attribute instance-attribute
output_path = None class-attribute instance-attribute
from_io(input_file, output) classmethod
Source code in src/tnh_scholar/cli_tools/sent_split/sent_split.py
@classmethod
def from_io(
    cls, input_file: Optional[Path], output: Optional[Path]
    ) -> "SplitIOData":
    return cls(input_path=input_file, output_path=output)
get_input_content()
Source code in src/tnh_scholar/cli_tools/sent_split/sent_split.py
def get_input_content(self) -> str:
    if self.content is not None:
        return self.content
    return read_str_from_file(self.input_path) if self.input_path \
        else sys.stdin.read()
write_output(result)
Source code in src/tnh_scholar/cli_tools/sent_split/sent_split.py
def write_output(self, result: SplitResult) -> None:
    text = result.text_object
    output_str = str(text)
    if self.output_path:
        write_str_to_file(self.output_path, output_str)
        click.echo(f"Output written to: {self.output_path.name}")
        click.echo(f"Split into {result.stats['sentence_count']} sentences.")
    else:
        click.echo(output_str)
SplitResult
Source code in src/tnh_scholar/cli_tools/sent_split/sent_split.py
@dataclass
class SplitResult:
    text_object: TextObject
    # `field(default_factory=dict)` avoids the ValueError dataclasses raise
    # for a mutable default like `{}`.
    stats: Dict[str, Any] = field(default_factory=dict)
stats = field(default_factory=dict) class-attribute instance-attribute
text_object instance-attribute
ensure_nltk_data(config)
Source code in src/tnh_scholar/cli_tools/sent_split/sent_split.py
def ensure_nltk_data(config: SplitConfig) -> None:
    try:
        nltk.data.find(f'tokenizers/{config.nltk_tokenizer}')
    except LookupError:
        try:
            nltk.download(config.nltk_tokenizer, quiet=True)
            nltk.data.find(f'tokenizers/{config.nltk_tokenizer}')
        except Exception as e:
            raise RuntimeError(
                f"Failed to download required NLTK data. "
                f"Please run 'python -m nltk.downloader {config.nltk_tokenizer}' "
                f"to install manually. Error: {e}"
            ) from e
main()
Source code in src/tnh_scholar/cli_tools/sent_split/sent_split.py
def main():
    sent_split()
sent_split(input_file, output, space)
Source code in src/tnh_scholar/cli_tools/sent_split/sent_split.py
@click.command()
@click.argument(
    "input_file", type=click.Path(exists=True, path_type=Path), required=False
    )
@click.option('-o', '--output', type=click.Path(path_type=Path), required=False,
              help='Output file (default: stdout)')
@click.option('-s', '--space', is_flag=True,
              help='Separate sentences with spaces instead of newlines')
def sent_split(input_file: Optional[Path],
               output: Optional[Path],
               space: bool) -> None:
    try:
        io_data = SplitIOData.from_io(input_file, output)
        config = SplitConfig(separator="space" if space else "newline")

        input_text = io_data.get_input_content()
        text = TextObject.from_str(input_text)

        result = split_text(text, config, io_data)
        io_data.write_output(result)

    except Exception as e:
        click.echo(f"Error processing text: {e}", err=True)
        sys.exit(1)
split_text(text, config, io_data)
Source code in src/tnh_scholar/cli_tools/sent_split/sent_split.py
def split_text(
    text: TextObject, config: SplitConfig, io_data: SplitIOData
    ) -> SplitResult:
    ensure_nltk_data(config)
    sentences = sent_tokenize(text.content)

    separator = "\n" if config.separator == "newline" else " "
    new_content = separator.join(sentences)

    text.transform(
        data_str=new_content,
        process_metadata=ProcessMetadata(
            step="split_text",
            processor="NLTK",
            tool="sent-split",
            source_file=io_data.input_path or None,
        )
    )

    return SplitResult(
        text_object=text,
        stats={"sentence_count": len(sentences)}
    )
sent_split_bak

Simple CLI tool for sentence splitting.

This module provides a command line interface for splitting text into sentences. Uses NLTK for robust sentence tokenization. Reads from stdin and writes to stdout by default, with optional file input/output.

ensure_nltk_data()

Ensure NLTK punkt tokenizer is available.

Source code in src/tnh_scholar/cli_tools/sent_split/sent_split_bak.py
def ensure_nltk_data():
    """Ensure NLTK punkt tokenizer is available."""
    try:
        # Try to find the resource
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        # If not found, try downloading
        try:
            nltk.download('punkt', quiet=True)
            # Verify download
            nltk.data.find('tokenizers/punkt')
        except Exception as e:
            raise RuntimeError(
                "Failed to download required NLTK data. "
                "Please run 'python -m nltk.downloader punkt' "
                f"to install manually. Error: {e}"
            ) from e
main()
Source code in src/tnh_scholar/cli_tools/sent_split/sent_split_bak.py
def main():
    sent_split()
process_text(text, newline=True)

Split text into sentences using NLTK.

Source code in src/tnh_scholar/cli_tools/sent_split/sent_split_bak.py
def process_text(text: TextObject, newline: bool = True) -> None:
    """Split text into sentences using NLTK."""
    ensure_nltk_data()
    sentences = sent_tokenize(text.content)

    new_content = "\n".join(sentences) if newline else " ".join(sentences)
    text.transform(data_str=new_content)
sent_split(input_file, output, space)

Split text into sentences using NLTK's sentence tokenizer.

Reads from stdin if no input file is specified. Writes to stdout if no output file is specified.

Source code in src/tnh_scholar/cli_tools/sent_split/sent_split_bak.py
@click.command()
@click.argument(
    "input_file", type=click.Path(exists=True, path_type=Path), required=False
    )
@click.option('-o', '--output', type=click.Path(path_type=Path), required=False,
              help='Output file (default: stdout)')
@click.option('-s', '--space', is_flag=True,
              help='Separate sentences with spaces instead of newlines')
def sent_split(input_file: Optional[Path],
               output: Optional[Path],
               space: bool) -> None:
    """Split text into sentences using NLTK's sentence tokenizer.

    Reads from stdin if no input file is specified.
    Writes to stdout if no output file is specified.
    """
    try:
        # Read from file or stdin
        input_text = read_str_from_file(input_file) if input_file else sys.stdin.read()

        # Process the text
        text = TextObject.from_str(input_text)
        process_text(text, newline=not space)

        process_metadata = ProcessMetadata(
            step="sentence-split",
            processor="NLTK", 
        )
        if input_file:
            process_metadata.update({"source_file": path_as_str(input_file)})

        text.transform(process_metadata=process_metadata)

        # Write to file or stdout
        if output:
            write_str_to_file(output, str(text))
        else:
            click.echo(text)

        if output:
            click.echo(f"Output written to: {output.name}")

    except Exception as e:
        click.echo(f"Error processing text: {e}", err=True)
        sys.exit(1)

srt_translate

__all__ = ['main', 'srt_translate'] module-attribute
main()

Entry point for the srt-translate CLI tool.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def main():
    """Entry point for the srt-translate CLI tool."""
    srt_translate()
srt_translate

CLI tool for translating SRT subtitle files using tnh-scholar line translation.

This module provides a command line interface for translating SRT subtitle files from one language to another while preserving timecodes and subtitle structure. Uses the same translation engine as tnh-fab translate.

logger = get_child_logger(__name__) module-attribute
SrtEntry

Represents a single subtitle entry from an SRT file.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
class SrtEntry:
    """Represents a single subtitle entry from an SRT file."""

    def __init__(self, index: int, start_time: str, end_time: str, text: str):
        """Initialize subtitle entry with timing and text."""
        self.index = index
        self.start_time = start_time
        self.end_time = end_time
        self.text = text.strip()

    def __str__(self) -> str:
        """Format entry as SRT text."""
        return f"{self.index}\n{self.start_time} --> {self.end_time}\n{self.text}\n"

    @property
    def line_key(self) -> str:
        """Generate a unique line key for this entry."""
        return f"{self.index}"
end_time = end_time instance-attribute
index = index instance-attribute
line_key property

Generate a unique line key for this entry.

start_time = start_time instance-attribute
text = text.strip() instance-attribute
__init__(index, start_time, end_time, text)

Initialize subtitle entry with timing and text.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def __init__(self, index: int, start_time: str, end_time: str, text: str):
    """Initialize subtitle entry with timing and text."""
    self.index = index
    self.start_time = start_time
    self.end_time = end_time
    self.text = text.strip()
__str__()

Format entry as SRT text.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def __str__(self) -> str:
    """Format entry as SRT text."""
    return f"{self.index}\n{self.start_time} --> {self.end_time}\n{self.text}\n"
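The layout produced by `__str__` can be checked directly; a quick reproduction of the expected SRT block format (the index, timecodes, and text here are hypothetical sample values):

```python
# Reproduce the SRT entry layout produced by SrtEntry.__str__:
# index, then "start --> end" timecodes, then the subtitle text.
index, start_time, end_time = 1, "00:00:01,000", "00:00:04,000"
text = "Hello, friends."
entry_str = f"{index}\n{start_time} --> {end_time}\n{text}\n"
print(entry_str)
```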
SrtTranslator

Translates SRT files while preserving timecodes.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
class SrtTranslator:
    """Translates SRT files while preserving timecodes."""

    def __init__(self, 
                 source_language: Optional[str] = None,
                 target_language: str = "en",
                 pattern: Optional[Prompt] = None,
                 model: Optional[str] = None,
                 metadata: Optional[Metadata] = None):
        """Initialize translator with language, model settings, and metadata."""
        self.source_language = source_language
        self.target_language = target_language
        self.pattern = pattern
        self.model = model
        self.metadata = metadata

    def parse_srt(self, content: str) -> List[SrtEntry]:
        """Parse SRT content into structured entries."""
        # Pattern matches: index, start time, end time, and multiline text
        pattern = r"(\d+)\r?\n(\d{2}:\d{2}:\d{2},\d{3}) --> (\d{2}:\d{2}:\d{2},\d{3})\r?\n((?:.+(?:\r?\n))+)(?:\r?\n)?"  # noqa: E501
        matches = re.findall(pattern, content, re.MULTILINE)

        entries = []
        for match in matches:
            index = int(match[0])
            start_time = match[1]
            end_time = match[2]
            text = match[3].strip()
            entries.append(SrtEntry(index, start_time, end_time, text))

        logger.info(f"Parsed {len(entries)} subtitle entries")
        return entries

    def entries_to_numbered_text(self, entries: List[SrtEntry]) -> str:
        """Convert SRT entries to numbered text for TextObject."""
        return "\n".join(entry.text for entry in entries)

    def create_text_object(self, text: str) -> TextObject:
        """Create a TextObject from the extracted SRT text with metadata."""
        return TextObject.from_str(
            text, language=self.source_language, metadata=self.metadata
            )

    def translate_text_object(self, text_object: TextObject) -> TextObject:
        """Translate the TextObject using line translation."""
        text_obj = translate_text_by_lines(
            text_object,
            source_language=self.source_language,
            target_language=self.target_language,
            pattern=self.pattern,
            model=self.model
        )
        logger.debug("Text generated: \n"
                      f"{text_obj}")
        return text_obj

    def extract_translated_lines(self, translated_object: TextObject) -> Dict[str, str]:
        """Extract translated lines from TextObject with line keys."""
        # Get the properly numbered content instead of raw content
        numbered_translation = translated_object.numbered_content
        logger.debug(f"Numbered translated text sample "
                     f":\n{numbered_translation[:500]}...")

        # Pattern matches line numbers and their text, 
        # accounting for the numbering format.
        # This depends on a consistent pattern for the lines.
        # This pattern will match the format like "1: Translated text"
        pattern = rf"(\d+){re.escape(translated_object.num_text.separator)}(.*)"

        translations = {}
        for line in numbered_translation.splitlines():
            if match := re.match(pattern, line):
                line_key = match[1]
                text = match[2].strip()
                translations[line_key] = text
                logger.debug(f"Found translation for key {line_key}: {text[:50]}...")

        logger.debug(f"Extracted {len(translations)} translations")
        return translations

    def update_entries_with_translations(self, 
                                        entries: List[SrtEntry], 
                                        translations: Dict[str, str]) -> List[SrtEntry]:
        """Apply translations to original entries."""
        updated_entries = []
        for entry in entries:
            # Look up translation by line key
            if entry.line_key in translations:
                entry.text = translations[entry.line_key]
            updated_entries.append(entry)

        return updated_entries

    def format_srt(self, entries: List[SrtEntry]) -> str:
        """Format entries back to SRT content."""
        return "\n".join(str(entry) for entry in entries)

    def translate_srt(self, content: str) -> str:
        """Process SRT content through complete translation pipeline."""
        entries = self.parse_srt(content)
        numbered_text = self.entries_to_numbered_text(entries)
        text_object = self.create_text_object(numbered_text)

        logger.info(f"Translating from {self.source_language or 'auto-detected'} "
                    f"to {self.target_language}")
        translated_object = self.translate_text_object(text_object)

        translations = self.extract_translated_lines(translated_object)
        updated_entries = self.update_entries_with_translations(entries, translations)

        return self.format_srt(updated_entries)

    def translate_and_save(self, input_file: Path, output_path: Path):
        """Handles file reading, translation, and saving."""

        logger.info(f"Reading SRT file: {input_file}")
        content = read_str_from_file(input_file)

        translated_content = self.translate_srt(content)

        write_str_to_file(output_path, translated_content, overwrite=True)
        logger.info(f"Translated SRT written to: {output_path}")
metadata = metadata instance-attribute
model = model instance-attribute
pattern = pattern instance-attribute
source_language = source_language instance-attribute
target_language = target_language instance-attribute
__init__(source_language=None, target_language='en', pattern=None, model=None, metadata=None)

Initialize translator with language, model settings, and metadata.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def __init__(self, 
             source_language: Optional[str] = None,
             target_language: str = "en",
             pattern: Optional[Prompt] = None,
             model: Optional[str] = None,
             metadata: Optional[Metadata] = None):
    """Initialize translator with language, model settings, and metadata."""
    self.source_language = source_language
    self.target_language = target_language
    self.pattern = pattern
    self.model = model
    self.metadata = metadata
create_text_object(text)

Create a TextObject from the extracted SRT text with metadata.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def create_text_object(self, text: str) -> TextObject:
    """Create a TextObject from the extracted SRT text with metadata."""
    return TextObject.from_str(
        text, language=self.source_language, metadata=self.metadata
        )
entries_to_numbered_text(entries)

Convert SRT entries to numbered text for TextObject.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def entries_to_numbered_text(self, entries: List[SrtEntry]) -> str:
    """Convert SRT entries to numbered text for TextObject."""
    return "\n".join(entry.text for entry in entries)
extract_translated_lines(translated_object)

Extract translated lines from TextObject with line keys.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def extract_translated_lines(self, translated_object: TextObject) -> Dict[str, str]:
    """Extract translated lines from TextObject with line keys."""
    # Get the properly numbered content instead of raw content
    numbered_translation = translated_object.numbered_content
    logger.debug(f"Numbered translated text sample "
                 f":\n{numbered_translation[:500]}...")

    # Pattern matches line numbers and their text, 
    # accounting for the numbering format.
    # This depends on a consistent pattern for the lines.
    # This pattern will match the format like "1: Translated text"
    pattern = rf"(\d+){re.escape(translated_object.num_text.separator)}(.*)"

    translations = {}
    for line in numbered_translation.splitlines():
        if match := re.match(pattern, line):
            line_key = match[1]
            text = match[2].strip()
            translations[line_key] = text
            logger.debug(f"Found translation for key {line_key}: {text[:50]}...")

    logger.debug(f"Extracted {len(translations)} translations")
    return translations
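Assuming the numbered format uses a simple `<index><separator><text>` layout (the real separator comes from `num_text.separator`; `":"` here is an assumption for illustration), the extraction step can be sketched as:

```python
import re

numbered = "1:Xin chào\n2:Cảm ơn bạn"
separator = ":"  # assumed separator; the real value comes from num_text.separator
pattern = rf"(\d+){re.escape(separator)}(.*)"

# Build a {line_key: translated_text} map, as extract_translated_lines does.
translations = {}
for line in numbered.splitlines():
    if match := re.match(pattern, line):
        translations[match[1]] = match[2].strip()

print(translations)
```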
format_srt(entries)

Format entries back to SRT content.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def format_srt(self, entries: List[SrtEntry]) -> str:
    """Format entries back to SRT content."""
    return "\n".join(str(entry) for entry in entries)
parse_srt(content)

Parse SRT content into structured entries.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def parse_srt(self, content: str) -> List[SrtEntry]:
    """Parse SRT content into structured entries."""
    # Pattern matches: index, start time, end time, and multiline text
    pattern = r"(\d+)\r?\n(\d{2}:\d{2}:\d{2},\d{3}) --> (\d{2}:\d{2}:\d{2},\d{3})\r?\n((?:.+(?:\r?\n))+)(?:\r?\n)?"  # noqa: E501
    matches = re.findall(pattern, content, re.MULTILINE)

    entries = []
    for match in matches:
        index = int(match[0])
        start_time = match[1]
        end_time = match[2]
        text = match[3].strip()
        entries.append(SrtEntry(index, start_time, end_time, text))

    logger.info(f"Parsed {len(entries)} subtitle entries")
    return entries
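The parsing regex can be exercised standalone on a small sample (the SRT content below is hypothetical):

```python
import re

# Same pattern as parse_srt: index, start/end timecodes, multiline text.
pattern = (r"(\d+)\r?\n(\d{2}:\d{2}:\d{2},\d{3}) --> "
           r"(\d{2}:\d{2}:\d{2},\d{3})\r?\n((?:.+(?:\r?\n))+)(?:\r?\n)?")

sample = (
    "1\n00:00:01,000 --> 00:00:03,000\nFirst line\n\n"
    "2\n00:00:04,000 --> 00:00:06,000\nSecond line\nwrapped text\n"
)
matches = re.findall(pattern, sample, re.MULTILINE)
for index, start, end, text in matches:
    print(index, start, end, text.strip())
```

Note that the text group stops at the blank line separating entries, so multi-line subtitles stay attached to their own entry.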
translate_and_save(input_file, output_path)

Handles file reading, translation, and saving.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def translate_and_save(self, input_file: Path, output_path: Path):
    """Handles file reading, translation, and saving."""

    logger.info(f"Reading SRT file: {input_file}")
    content = read_str_from_file(input_file)

    translated_content = self.translate_srt(content)

    write_str_to_file(output_path, translated_content, overwrite=True)
    logger.info(f"Translated SRT written to: {output_path}")
translate_srt(content)

Process SRT content through complete translation pipeline.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def translate_srt(self, content: str) -> str:
    """Process SRT content through complete translation pipeline."""
    entries = self.parse_srt(content)
    numbered_text = self.entries_to_numbered_text(entries)
    text_object = self.create_text_object(numbered_text)

    logger.info(f"Translating from {self.source_language or 'auto-detected'} "
                f"to {self.target_language}")
    translated_object = self.translate_text_object(text_object)

    translations = self.extract_translated_lines(translated_object)
    updated_entries = self.update_entries_with_translations(entries, translations)

    return self.format_srt(updated_entries)
translate_text_object(text_object)

Translate the TextObject using line translation.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def translate_text_object(self, text_object: TextObject) -> TextObject:
    """Translate the TextObject using line translation."""
    text_obj = translate_text_by_lines(
        text_object,
        source_language=self.source_language,
        target_language=self.target_language,
        pattern=self.pattern,
        model=self.model
    )
    logger.debug("Text generated: \n"
                  f"{text_obj}")
    return text_obj
update_entries_with_translations(entries, translations)

Apply translations to original entries.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def update_entries_with_translations(self, 
                                    entries: List[SrtEntry], 
                                    translations: Dict[str, str]) -> List[SrtEntry]:
    """Apply translations to original entries."""
    updated_entries = []
    for entry in entries:
        # Look up translation by line key
        if entry.line_key in translations:
            entry.text = translations[entry.line_key]
        updated_entries.append(entry)

    return updated_entries
load_metadata_from_file(metadata_file)

Load metadata from a file if provided.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def load_metadata_from_file(metadata_file: Optional[Path]) -> Optional[Metadata]:
    """Load metadata from a file if provided."""
    if not metadata_file:
        return None

    try:
        metadata, _ = Frontmatter.extract_from_file(metadata_file)
        logger.info(f"Loaded metadata from {metadata_file}")
        return metadata
    except FileNotFoundError:
        logger.error(f"Metadata file not found: {metadata_file}")
        sys.exit(1)
    except Exception as e:
        logger.error(f"Failed to load metadata from {metadata_file}: {e}")
        sys.exit(1)
main()

Entry point for the srt-translate CLI tool.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def main():
    """Entry point for the srt-translate CLI tool."""
    srt_translate()
set_output_path(input_file, output, target_language)
Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def set_output_path(input_file: Path, output: Optional[Path], target_language: str) -> Path:
    """Return the output path, defaulting to the input name with a language suffix."""
    if not output:
        return input_file.with_stem(f"{input_file.stem}_{target_language}")
    return output
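The default naming relies on `Path.with_stem` (Python 3.9+); for a hypothetical input file:

```python
from pathlib import Path

# Default output name: input stem plus target-language suffix, same extension.
input_file = Path("dharma_talk.srt")
target_language = "en"
output = input_file.with_stem(f"{input_file.stem}_{target_language}")
print(output)  # dharma_talk_en.srt
```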
set_pattern(pattern)
Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
def set_pattern(pattern: Optional[str]) -> Optional[Prompt]:
    """Load the named translation pattern, exiting on failure."""
    pattern_obj = None
    if pattern:
        try:
            pattern_obj = get_pattern(pattern)
        except Exception as e:
            logger.error(f"Failed to load pattern '{pattern}': {e}")
            sys.exit(1)
    return pattern_obj
srt_translate(input_file, output=None, source_language=None, target_language='en', model=None, pattern=None, debug=False, metadata=None)

Translate SRT subtitle files from one language to another.

INPUT_FILE is the path to the SRT file to translate.

Source code in src/tnh_scholar/cli_tools/srt_translate/srt_translate.py
@click.command()
@click.argument("input_file", type=click.Path(exists=True, path_type=Path))
@click.option("-o", "--output", type=click.Path(path_type=Path),
              help="Output file path (default: adds language suffix to input filename)")
@click.option("-s", "--source-language", 
              help="Source language code (auto-detected if not specified)")
@click.option("-t", "--target-language", default="en", 
              help="Target language code (default: en)")
@click.option("-m", "--model", help="Optional model name to use for translation")
@click.option("-p", "--pattern", help="Optional translation pattern name")
@click.option("-g", "--debug", is_flag=True, help="Option to show debug output.")
@click.option("-d", "--metadata", type=click.Path(exists=True, path_type=Path),
              help="Path to file with YAML metadata as frontmatter, "
                   "providing translation context")
def srt_translate(
    input_file: Path,
    output: Optional[Path] = None,
    source_language: Optional[str] = None,
    target_language: str = "en",
    model: Optional[str] = None,
    pattern: Optional[str] = None,
    debug: Optional[bool] = False,
    metadata: Optional[Path] = None,
) -> None:
    """
    Translate SRT subtitle files from one language to another.

    INPUT_FILE is the path to the SRT file to translate.
    """

    if debug:
        setup_logging(log_level=logging.DEBUG)
    else:
        setup_logging()

    try:
        output_path = set_output_path(input_file, output, target_language)
        pattern_obj = set_pattern(pattern)
        if metadata_obj := load_metadata_from_file(metadata):
            logger.info(f"Using metadata for translation context from: {metadata}")


        translator = SrtTranslator(
            source_language=source_language,
            target_language=target_language,
            pattern=pattern_obj,
            model=model,
            metadata=metadata_obj,
        )

        translator.translate_and_save(input_file, output_path)

    except Exception as e:
        logger.error(f"Error translating SRT: {e}")
        sys.exit(1)

tnh_fab

tnh_fab

TNH-FAB Command Line Interface

Part of the THICH NHAT HANH SCHOLAR (TNH_SCHOLAR) project. A rapid prototype implementation of the TNH-FAB command-line tool for OpenAI-based text processing. Provides core functionality for text punctuation, sectioning, translation, and general processing.

DEFAULT_SECTION_PATTERN = 'default_section' module-attribute
DEFAULT_TRANSLATE_PATTERN = 'default_line_translate' module-attribute
logger = get_child_logger(__name__) module-attribute
pass_config = click.make_pass_decorator(TNHFabConfig, ensure=True) module-attribute
TNHFabConfig

Holds configuration for the TNH-FAB CLI tool.

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
class TNHFabConfig:
    """Holds configuration for the TNH-FAB CLI tool."""

    def __init__(self):
        self.verbose: bool = False
        self.debug: bool = False
        self.quiet: bool = False
        # Initialize pattern manager with directory set in .env file or default.

        load_dotenv()

        if pattern_path_name := os.getenv("TNH_PATTERN_DIR"):
            pattern_dir = Path(pattern_path_name)
            logger.debug(f"pattern dir: {pattern_path_name}")
        else:
            pattern_dir = TNH_DEFAULT_PATTERN_DIR

        pattern_dir.mkdir(parents=True, exist_ok=True)
        self.pattern_manager = PromptCatalog(pattern_dir)
debug = False instance-attribute
pattern_manager = PromptCatalog(pattern_dir) instance-attribute
quiet = False instance-attribute
verbose = False instance-attribute
__init__()
Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
def __init__(self):
    self.verbose: bool = False
    self.debug: bool = False
    self.quiet: bool = False
    # Initialize pattern manager with directory set in .env file or default.

    load_dotenv()

    if pattern_path_name := os.getenv("TNH_PATTERN_DIR"):
        pattern_dir = Path(pattern_path_name)
        logger.debug(f"pattern dir: {pattern_path_name}")
    else:
        pattern_dir = TNH_DEFAULT_PATTERN_DIR

    pattern_dir.mkdir(parents=True, exist_ok=True)
    self.pattern_manager = PromptCatalog(pattern_dir)
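The environment-variable fallback can be sketched in isolation (the default directory below is illustrative; the package actually falls back to `TNH_DEFAULT_PATTERN_DIR`):

```python
import os
from pathlib import Path

# Illustrative stand-in for TNH_DEFAULT_PATTERN_DIR.
DEFAULT_PATTERN_DIR = Path.home() / ".config" / "tnh-scholar" / "patterns"

def resolve_pattern_dir() -> Path:
    """Use TNH_PATTERN_DIR when set, otherwise fall back to the default."""
    if pattern_path_name := os.getenv("TNH_PATTERN_DIR"):
        return Path(pattern_path_name)
    return DEFAULT_PATTERN_DIR

os.environ["TNH_PATTERN_DIR"] = "/tmp/patterns"
print(resolve_pattern_dir())  # /tmp/patterns
```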
export_processed_sections(section_result, text_obj)
Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
def export_processed_sections(
    section_result: Generator[ProcessedSection, None, None], 
    text_obj: TextObject) -> None:
    click.echo(Frontmatter.generate(text_obj.metadata))
    for processed_section in section_result:
        click.echo(processed_section.processed_str)
        click.echo("\n")  # newline separated output. 
gen_text_input(ctx, input_file)

Read input from file or stdin.

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
def gen_text_input(ctx: Context, input_file: Optional[Path]) -> TextObject:
    """Read input from file or stdin."""
    if input_file:
        return TextObject.load(input_file)
    if not sys.stdin.isatty():
        return TextObject.from_str(sys.stdin.read())
    raise UsageError("No input provided")
get_pattern(pattern_manager, pattern_name)

Get pattern from the pattern manager.

Parameters:

- pattern_manager (PromptCatalog): Initialized PatternManager instance. Required.
- pattern_name (str): Name of the pattern to load. Required.

Returns:

- Prompt: Loaded pattern object.

Raises:

- click.ClickException: If pattern cannot be loaded.

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
def get_pattern(pattern_manager: PromptCatalog, pattern_name: str) -> Prompt:
    """
    Get pattern from the pattern manager.

    Args:
        pattern_manager: Initialized PatternManager instance
        pattern_name: Name of the pattern to load

    Returns:
        Pattern: Loaded pattern object

    Raises:
        click.ClickException: If pattern cannot be loaded
    """
    try:
        return pattern_manager.load_pattern(pattern_name)
    except FileNotFoundError as e:
        raise click.ClickException(
            f"Pattern '{pattern_name}' not found in {pattern_manager.base_path}"
        ) from e
    except Exception as e:
        raise click.ClickException(f"Error loading pattern: {e}") from e
main()

Entry point for TNH-FAB CLI tool.

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
def main():
    """Entry point for TNH-FAB CLI tool."""
    tnh_fab()
process(config, input_file, pattern, section, auto, paragraph, template)

Apply custom pattern-based processing to text with flexible structuring options.

This command provides flexible text processing using customizable patterns. It can process text by sections (defined in a JSON file or auto-detected), by paragraphs, or as a whole (the default). This is particularly useful for formatting, restructuring, or applying consistent transformations to text.

Examples:


# Process using a specific pattern
$ tnh-fab process -p format_xml input.txt


# Process using paragraph mode
$ tnh-fab process -p format_xml -g input.txt


# Process with custom sections
$ tnh-fab process -p format_xml -s sections.json input.txt


# Process with template values
$ tnh-fab process -p format_xml -t template.yaml input.txt

Processing Modes:


1. Single Input Mode (default)
    - Processes entire input.


2. Section Mode (-s):
    - Uses sections from a JSON file
    - Processes each section according to pattern


3. Paragraph Mode (-g):
    - Treats each line/paragraph as a separate unit
    - Useful for simpler processing tasks
    - More memory efficient for large files


4. Auto Section Mode (-a):
    - Automatically sections the input file 
    - Processes by section

Notes:
- Required pattern must exist in pattern directory
- Template values can customize pattern behavior
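The precedence among these modes can be sketched as a simple dispatch (a sketch of the documented flag behavior, not the CLI's actual code):

```python
from typing import Optional

def choose_mode(paragraph: bool, section_file: Optional[str], auto: bool) -> str:
    """Mirror the flag precedence of `tnh-fab process`: -g, then -s, then -a."""
    if paragraph:
        return "paragraph"
    if section_file is not None:
        return "section-file"
    if auto:
        return "auto-section"
    return "whole-text"

print(choose_mode(False, None, True))  # auto-section
```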

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
@tnh_fab.command()
@click.argument(
    "input_file", type=click.Path(exists=True, path_type=Path), required=False
)
@click.option("-p", "--pattern", required=True, help="Pattern name for processing")
@click.option(
    "-s",
    "--section",
    type=click.Path(exists=True, path_type=Path),
    help="Process using sections from JSON file.",
)
@click.option(
    "-a", "--auto", is_flag=True, help="Automatically generate and process by sections."
)
@click.option("-g", "--paragraph", is_flag=True, help="Process text by paragraphs")
@click.option(
    "-t",
    "--template",
    type=click.Path(exists=True, path_type=Path),
    help="YAML file containing template values",
)
@pass_config
def process(
    config: TNHFabConfig,
    input_file: Optional[Path],
    pattern: str,
    section: Optional[Path],
    auto: bool,
    paragraph: bool,
    template: Optional[Path],
):
    """Apply custom pattern-based processing to text with flexible structuring options.

    This command provides flexible text processing using customizable patterns. It can
    process text by sections (defined in a JSON file or auto-detected), by paragraphs,
    or as a whole (the default). This is particularly useful for formatting,
    restructuring, or applying consistent transformations to text.

    Examples:

        \b
        # Process using a specific pattern
        $ tnh-fab process -p format_xml input.txt

        \b
        # Process using paragraph mode
        $ tnh-fab process -p format_xml -g input.txt

        \b
        # Process with custom sections
        $ tnh-fab process -p format_xml -s sections.json input.txt

        \b
        # Process with template values
        $ tnh-fab process -p format_xml -t template.yaml input.txt


    Processing Modes:

        \b
        1. Single Input Mode (default)
            - Processes entire input.

        \b
        2. Section Mode (-s):
            - Uses sections from a JSON file
            - Processes each section according to pattern

        \b
        3. Paragraph Mode (-g):
            - Treats each line/paragraph as a separate unit
            - Useful for simpler processing tasks
            - More memory efficient for large files

        \b
        4. Auto Section Mode (-a):
            - Automatically sections the input file 
            - Processes by section

    \b
    Notes:
        - Required pattern must exist in pattern directory
        - Template values can customize pattern behavior

    """
    text_obj = gen_text_input(click, input_file)  # type: ignore

    process_pattern = get_pattern(config.pattern_manager, pattern)

    template_dict: Dict[str, str] = {}

    if paragraph:
        result = process_text_by_paragraphs(
            text_obj, template_dict, pattern=process_pattern
        )
        export_processed_sections(result, text_obj)        
    elif section is not None:  # Section mode (either file or auto-generate)    
        text_obj = TextObject.from_section_file(section, text_obj.content)

        result = process_text_by_sections(
            text_obj, template_dict, pattern=process_pattern
        )
        export_processed_sections(result, text_obj)
    elif auto:
        # Auto-generate sections     
        default_section_pattern = get_pattern(
            config.pattern_manager, DEFAULT_SECTION_PATTERN
        )
        text_obj = find_sections(text_obj, section_pattern=default_section_pattern)

        result = process_text_by_sections(
            text_obj, template_dict, pattern=process_pattern
        )
        export_processed_sections(result, text_obj)

    else:
        result = process_text(
            text_obj, pattern=process_pattern, template_dict=template_dict
        )
        click.echo(result)
punctuate(input_file, language, style, review_count, pattern)

[DEPRECATED] Punctuation command is deprecated.

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
@tnh_fab.command()
@click.argument(
    "input_file", type=click.Path(exists=True, path_type=Path), required=False
)
@click.option("-l", "--language", help="[DEPRECATED] Source language code")
@click.option("-y", "--style", help="[DEPRECATED] Punctuation style")
@click.option("-c", "--review-count", type=int, 
              help="[DEPRECATED] Number of review passes")
@click.option("-p", "--pattern", help="[DEPRECATED] Pattern name for punctuation")
def punctuate(
    input_file: Optional[Path],
    language: Optional[str],
    style: Optional[str],
    review_count: Optional[int], 
    pattern: Optional[str],
):
    """[DEPRECATED] Punctuation command is deprecated."""
    click.echo(
        "\nDEPRECATED: The 'punctuate' command is deprecated.\n"
        "Please use: tnh-fab process -p <punctuation_pattern>\n\n"
        "Example:\n"
        "  tnh-fab process -p default_punctuate input.txt\n"
    )
    sys.exit(1)
section(config, input_file, language, num_sections, review_count, pattern)

Analyze and divide text into logical sections based on content.

This command processes the input text to identify coherent sections based on content analysis. It generates a structured representation of the text with sections that maintain logical continuity. Each section includes metadata such as title and line range.

Examples:


# Auto-detect sections in a file
$ tnh-fab section input.txt


# Specify desired number of sections
$ tnh-fab section -n 5 input.txt


# Process Vietnamese text with custom pattern
$ tnh-fab section -l vi -p custom_section_pattern input.txt


# Section text from stdin with increased review
$ cat input.txt | tnh-fab section -c 5

Output Format: JSON object containing:
- language: Detected or specified language code
- sections: Array of section objects, each with:
  - title: Section title in original language
  - start_line: Starting line number (inclusive)
  - end_line: Ending line number (inclusive)

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
@tnh_fab.command()
@click.argument(
    "input_file", type=click.Path(exists=True, path_type=Path), required=False
)
@click.option(
    "-l",
    "--language",
    help="Source language code (e.g., 'en', 'vi'). Auto-detected if not specified.",
)
@click.option(
    "-n",
    "--num-sections",
    type=int,
    help="Target number of sections (auto-calculated if not specified)",
)
@click.option(
    "-c",
    "--review-count",
    type=int,
    default=3,
    help="Number of review passes (default: 3)",
)
@click.option(
    "-p",
    "--pattern",
    default=DEFAULT_SECTION_PATTERN,
    help=f"Pattern name for section analysis (default: '{DEFAULT_SECTION_PATTERN}')",
)
@pass_config
def section(
    config: TNHFabConfig,
    input_file: Optional[Path],
    language: Optional[str],
    num_sections: Optional[int],
    review_count: int,
    pattern: str,
):
    """Analyze and divide text into logical sections based on content.

    This command processes the input text to identify coherent sections based on content
    analysis. It generates a structured representation of the text with sections that
    maintain logical continuity. Each section includes metadata such as title and line
    range.

    Examples:

        \b
        # Auto-detect sections in a file
        $ tnh-fab section input.txt

        \b
        # Specify desired number of sections
        $ tnh-fab section -n 5 input.txt

        \b
        # Process Vietnamese text with custom pattern
        $ tnh-fab section -l vi -p custom_section_pattern input.txt

        \b
        # Section text from stdin with increased review
        $ cat input.txt | tnh-fab section -c 5

    \b
    Output Format:
        JSON object containing:
        - language: Detected or specified language code
        - sections: Array of section objects, each with:
            - title: Section title in original language
            - start_line: Starting line number (inclusive)
            - end_line: Ending line number (inclusive)
    """
    input_text = gen_text_input(click, input_file)  # type: ignore
    section_pattern = get_pattern(config.pattern_manager, pattern)

    text_object = find_sections(
        input_text,
        section_pattern=section_pattern,
        section_count=num_sections,
        review_count=review_count,
    )
    # For prototype, just output the JSON representation
    info = text_object.export_info(input_file)
    click.echo(info.model_dump_json(indent=2))
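The JSON emitted by `tnh-fab section` can be consumed downstream with nothing but the standard library. A minimal sketch, assuming the output shape described above; the titles and line ranges here are invented for illustration:

```python
import json

# Hypothetical `tnh-fab section` output (field values invented for the demo;
# real titles and line ranges depend on the input text).
sample_output = """
{
  "language": "vi",
  "sections": [
    {"title": "Loi mo dau", "start_line": 1, "end_line": 40},
    {"title": "Phap thoai", "start_line": 41, "end_line": 120}
  ]
}
"""

data = json.loads(sample_output)

def section_slices(lines, sections):
    """Recover each section's text using the inclusive start/end line fields."""
    for sec in sections:
        yield sec["title"], lines[sec["start_line"] - 1 : sec["end_line"]]

lines = [f"line {i}" for i in range(1, 121)]
slices = list(section_slices(lines, data["sections"]))
```

Because `start_line`/`end_line` are 1-based and inclusive, the slice above converts to Python's 0-based, end-exclusive indexing.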
tnh_fab(ctx, verbose, debug, quiet)

TNH-FAB: Thich Nhat Hanh Scholar Text processing command-line tool.

CORE COMMANDS: punctuate, section, translate, process

To get help on any command and see its options:

tnh-fab [COMMAND] --help

Provides specialized processing for multi-lingual Dharma content.

Offers functionalities for punctuation, sectioning, line-based translation, and general text processing based on predefined patterns. Input text can be provided either via a file or standard input.

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
@click.group()
@click.option("-v", "--verbose", is_flag=True, 
              help="Enable detailed logging. (NOT implemented)")
@click.option("--debug", is_flag=True, help="Enable debug output")
@click.option("--quiet", is_flag=True, help="Suppress all non-error output")
@click.pass_context
def tnh_fab(ctx: Context, verbose: bool, debug: bool, quiet: bool):
    """TNH-FAB: Thich Nhat Hanh Scholar Text processing command-line tool.

    CORE COMMANDS: punctuate, section, translate, process

    To get help on any command and see its options:

    tnh-fab [COMMAND] --help

    Provides specialized processing for multi-lingual Dharma content.

    Offers functionalities for punctuation, sectioning, line-based translation,
    and general text processing based on predefined patterns.
    Input text can be provided either via a file or standard input.
    """        
    config = ctx.ensure_object(TNHFabConfig)

    if not check_openai_env():

        raise click.ClickException("Missing OpenAI Credentials.")

    config.verbose = verbose
    config.debug = debug
    config.quiet = quiet

    if not quiet:
        if debug:
            setup_logging(log_level=logging.DEBUG)
        else:
            setup_logging(log_level=logging.INFO)
translate(config, input_file, language, target, style, context_lines, segment_size, pattern)

Translate text while preserving line numbers and contextual understanding.

This command performs intelligent translation that maintains line number correspondence between source and translated text. It uses surrounding context to improve translation accuracy and consistency, particularly important for texts where terminology and context are crucial.

Examples:

    # Translate Vietnamese text to English
    $ tnh-fab translate -l vi input.txt

    # Translate to French with specific style
    $ tnh-fab translate -l vi -r fr -y "Formal" input.txt

    # Translate with increased context
    $ tnh-fab translate --context-lines 5 input.txt

    # Translate using custom segment size
    $ tnh-fab translate --segment-size 10 input.txt

Notes:
    - Line numbers are preserved in the output
    - Context lines are used to improve translation accuracy
    - Segment size affects processing speed and memory usage

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
@tnh_fab.command()
@click.argument(
    "input_file", type=click.Path(exists=True, path_type=Path), required=False
)
@click.option(
    "-l", "--language", help="Source language code. Auto-detected if not specified."
)
@click.option(
    "-r", "--target", default="en", help="Target language code (default: 'en')"
)
@click.option(
    "-y", "--style", help="Translation style (e.g., 'American Dharma Teaching')"
)
@click.option(
    "--context-lines",
    type=int,
    default=3,
    help="Number of context lines to consider (default: 3)",
)
@click.option(
    "--segment-size",
    type=int,
    help="Lines per translation segment (auto-calculated if not specified)",
)
@click.option(
    "-p",
    "--pattern",
    default=DEFAULT_TRANSLATE_PATTERN,
    help=f"Pattern name for translation (default: '{DEFAULT_TRANSLATE_PATTERN}')",
)
@pass_config
def translate(
    config: TNHFabConfig,
    input_file: Optional[Path],
    language: Optional[str],
    target: str,
    style: Optional[str],
    context_lines: int,
    segment_size: Optional[int],
    pattern: str,
):
    """Translate text while preserving line numbers and contextual understanding.

    This command performs intelligent translation that maintains 
    line number correspondence between source and translated text. 
    It uses surrounding context to improve translation
    accuracy and consistency, particularly important for texts 
    where terminology and context are crucial.

    Examples:

        \b
        # Translate Vietnamese text to English
        $ tnh-fab translate -l vi input.txt

        \b
        # Translate to French with specific style
        $ tnh-fab translate -l vi -r fr -y "Formal" input.txt

        \b
        # Translate with increased context
        $ tnh-fab translate --context-lines 5 input.txt

        \b
        # Translate using custom segment size
        $ tnh-fab translate --segment-size 10 input.txt

    \b
    Notes:
        - Line numbers are preserved in the output
        - Context lines are used to improve translation accuracy
        - Segment size affects processing speed and memory usage
    """
    text_obj = gen_text_input(click, input_file)  # type: ignore
    translation_pattern = get_pattern(config.pattern_manager, pattern)

    text_obj.update_metadata(source_file=input_file)

    text_obj = translate_text_by_lines(
        text_obj,
        source_language=language,
        target_language=target,
        pattern=translation_pattern,
        style=style,
        context_lines=context_lines,
        segment_size=segment_size,
    )
    click.echo(text_obj)
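The segmentation itself happens inside `translate_text_by_lines`, which is not shown here. As a rough sketch of the idea behind `--segment-size` and `--context-lines` (the boundary and context handling below are assumptions for illustration, not the tool's exact algorithm):

```python
def make_segments(lines, segment_size, context_lines):
    """Split numbered lines into fixed-size segments, attaching a few
    surrounding lines as untranslated context for the translator."""
    segments = []
    for start in range(0, len(lines), segment_size):
        end = min(start + segment_size, len(lines))
        segments.append({
            "lines": lines[start:end],                            # lines to translate
            "context_before": lines[max(0, start - context_lines):start],
            "context_after": lines[end:end + context_lines],      # read-only context
            "start_line": start + 1,                              # 1-based, preserved
        })
    return segments

segs = make_segments([f"cau {i}" for i in range(1, 11)],
                     segment_size=4, context_lines=2)
```

Keeping the 1-based `start_line` on each segment is what lets the translated output preserve line-number correspondence with the source.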

tnh_setup

tnh_setup
OPENAI_ENV_HELP_MSG = "\n>>>>>>>>>> OpenAI API key not found in environment. <<<<<<<<<\n\nFor AI processing with TNH-scholar:\n\n1. Get an API key from https://platform.openai.com/api-keys\n2. Set the OPENAI_API_KEY environment variable:\n\n export OPENAI_API_KEY='your-api-key-here' # Linux/Mac\n set OPENAI_API_KEY=your-api-key-here # Windows\n\nFor OpenAI API access help: https://platform.openai.com/\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>> -- <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n" module-attribute
PATTERNS_URL = 'https://github.com/aaronksolomon/patterns/archive/main.zip' module-attribute
create_config_dirs()

Create required configuration directories.

Source code in src/tnh_scholar/cli_tools/tnh_setup/tnh_setup.py
def create_config_dirs():
    """Create required configuration directories."""
    TNH_CONFIG_DIR.mkdir(parents=True, exist_ok=True)
    TNH_LOG_DIR.mkdir(exist_ok=True)
    TNH_DEFAULT_PATTERN_DIR.mkdir(exist_ok=True)
download_patterns()

Download and extract pattern files from GitHub.

Source code in src/tnh_scholar/cli_tools/tnh_setup/tnh_setup.py
def download_patterns() -> bool:
    """Download and extract pattern files from GitHub."""
    try:
        response = requests.get(PATTERNS_URL)
        response.raise_for_status()

        with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
            root_dir = zip_ref.filelist[0].filename.split('/')[0]

            for zip_info in zip_ref.filelist:
                if zip_info.filename.endswith('.md'):
                    rel_path = Path(zip_info.filename).relative_to(root_dir)
                    target_path = TNH_DEFAULT_PATTERN_DIR / rel_path

                    target_path.parent.mkdir(parents=True, exist_ok=True)

                    with zip_ref.open(zip_info) as source, open(target_path, 'wb') as target:
                        target.write(source.read())
        return True

    except Exception as e:
        click.echo(f"Pattern download failed: {e}", err=True)
        return False
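The `.md` filter and root-directory stripping above can be exercised against an in-memory archive shaped like a GitHub download. A self-contained check (the archive and file names are invented for the demo):

```python
import io
import zipfile
from pathlib import Path

# Build a zip shaped like a GitHub archive: one root folder containing files.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("patterns-main/default_punctuate.md", "# pattern")
    zf.writestr("patterns-main/README.txt", "not a pattern")

extracted = []
with zipfile.ZipFile(buf) as zip_ref:
    # Same logic as download_patterns(): take the archive's root folder name
    # from the first entry, then keep only Markdown files, re-rooted.
    root_dir = zip_ref.filelist[0].filename.split("/")[0]
    for zip_info in zip_ref.filelist:
        if zip_info.filename.endswith(".md"):
            extracted.append(Path(zip_info.filename).relative_to(root_dir))
```

The `relative_to(root_dir)` step is what turns `patterns-main/foo/bar.md` into `foo/bar.md` under `TNH_DEFAULT_PATTERN_DIR`.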
main()

Entry point for setup CLI tool.

Source code in src/tnh_scholar/cli_tools/tnh_setup/tnh_setup.py
def main():
    """Entry point for setup CLI tool."""
    tnh_setup()
tnh_setup(skip_env, skip_patterns)

Set up TNH Scholar configuration.

Source code in src/tnh_scholar/cli_tools/tnh_setup/tnh_setup.py
@click.command()
@click.option('--skip-env', is_flag=True, help='Skip API key setup')
@click.option('--skip-patterns', is_flag=True, help='Skip pattern download')
def tnh_setup(skip_env: bool, skip_patterns: bool):
    """Set up TNH Scholar configuration."""
    click.echo("Setting up TNH Scholar...")

    # Create config directories
    create_config_dirs()
    click.echo(f"Created config directory: {TNH_CONFIG_DIR}")

    # Pattern download
    if not skip_patterns and click.confirm(
                "\nDownload pattern (markdown text) files from GitHub?\n"
                f"Source: {PATTERNS_URL}\n"
                f"Target: {TNH_DEFAULT_PATTERN_DIR}"
            ):
        if download_patterns():
            click.echo("Pattern files downloaded successfully")
        else:
            click.echo("Pattern download failed", err=True)

    # Environment test:
    if not skip_env:
        load_dotenv()  # for development
        if not check_openai_env(output=False):
            print(OPENAI_ENV_HELP_MSG)

tnh_tree

Developer tool for the tnh-scholar project.

This script generates a directory tree for the entire project and for the src directory, saving the results to 'project_directory_tree.txt' and 'src_directory_tree.txt' respectively.

Uses the generic module generate_tree which has a basic function build_tree that executes tree building.

Exposed as a script via pyproject.toml under the name 'tnh-tree'.

main()

CLI entry point registered as tnh-tree.

Source code in src/tnh_scholar/cli_tools/tnh_tree.py
def main() -> None:
    """CLI entry point registered as ``tnh-tree``."""
    build_tree(TNH_PROJECT_ROOT_DIR, TNH_PROJECT_ROOT_DIR / "src")

token_count

token_count
main()

Entry point for the token-count CLI tool.

Source code in src/tnh_scholar/cli_tools/token_count/token_count.py
def main():
    """Entry point for the token-count CLI tool."""
    token_count_cli()
token_count_cli(input_file)

Return the OpenAI API token count of a text file, based on gpt-4o.

Source code in src/tnh_scholar/cli_tools/token_count/token_count.py
@click.command()
@click.argument("input_file", type=click.File("r"), default="-")
def token_count_cli(input_file):
    """Return the Open AI API token count of a text file. Based on gpt-4o."""
    text = input_file.read()
    result = token_count(text)
    click.echo(result)

ytt_fetch

__all__ = ['main', 'ytt_fetch'] module-attribute
main()
Source code in src/tnh_scholar/cli_tools/ytt_fetch/ytt_fetch.py
def main():
    ytt_fetch()
ytt_fetch

Simple CLI tool for retrieving video transcripts.

This module provides a command line interface for downloading video transcripts in specified languages. It uses yt-dlp for video info extraction.

logger = get_child_logger(__name__) module-attribute
cleanup_files(keep, filepath)
Source code in src/tnh_scholar/cli_tools/ytt_fetch/ytt_fetch.py
def cleanup_files(keep: bool, filepath: Path) -> None:
    if not keep:
        filepath.unlink()
        logger.debug(f"Removed local data file: {filepath}")
export_data(output_path, data)
Source code in src/tnh_scholar/cli_tools/ytt_fetch/ytt_fetch.py
def export_data(output_path, data):
    if output_path:
        write_str_to_file(output_path, data, overwrite=True)
        click.echo(f"Data written to: {output_path}")
    else:
        click.echo(data)
export_ttml_data(metadata, ttml_path, no_embed, output_path, keep)
Source code in src/tnh_scholar/cli_tools/ytt_fetch/ytt_fetch.py
def export_ttml_data(
    metadata: Metadata, 
    ttml_path: Optional[Path], 
    no_embed: bool, 
    output_path: Optional[Path], 
    keep: bool):
    try:
        # export transcript as text 
        if ttml_path:
            transcript_text = extract_text_from_ttml(ttml_path)
        else:
            click.echo("Transcript Error. No ttml file found.")
            sys.exit(1)

        if not no_embed:
            transcript_text = Frontmatter.embed(metadata, transcript_text)

        export_data(output_path, transcript_text)   
        cleanup_files(keep, ttml_path)

    except FileNotFoundError as e:
        click.echo(f"File not found error: {e}", err=True)
        sys.exit(1)
    except (IOError, OSError) as e:
        click.echo(f"Error writing transcript to file: {e}", err=True)
        sys.exit(1)
    except TypeError as e:
        click.echo(f"Type error: {e}", err=True)
        sys.exit(1)
generate_metadata(dl, url, keep, output_path)
Source code in src/tnh_scholar/cli_tools/ytt_fetch/ytt_fetch.py
def generate_metadata(
    dl: DLPDownloader, 
    url: str, 
    keep: bool,
    output_path: Optional[Path]
    ) -> None:
    metadata = dl.get_metadata(url)
    metadata_out = metadata.text_embed("") # Only metadata

    export_data(output_path, metadata_out)
generate_transcript(dl, url, lang, keep, no_embed, output_path)
Source code in src/tnh_scholar/cli_tools/ytt_fetch/ytt_fetch.py
def generate_transcript(
    dl: DLPDownloader, 
    url: str, 
    lang: str, 
    keep: bool, 
    no_embed: bool,
    output_path: Optional[Path]
    ) -> None:

    metadata, ttml_path = get_ttml_download(dl, url, lang, output_path)

    process_metadata = ProcessMetadata(
            step="generate_transcript",
            processor="DLPDownloader",
            tool="ytt-fetch"
            )
    if output_path:
        process_metadata.update(output_path=output_path)

    metadata.add_process_info(process_metadata)

    export_ttml_data(metadata, ttml_path, no_embed, output_path, keep)
get_ttml_download(dl, url, lang, output_path)
Source code in src/tnh_scholar/cli_tools/ytt_fetch/ytt_fetch.py
def get_ttml_download(dl, url, lang, output_path):
    try:
        transcript_data = dl.get_transcript(url, lang, output_path)
        metadata = transcript_data.metadata
        ttml_path = transcript_data.filepath

    except TranscriptError as e:
        click.echo(f"Transcript error {e}", err=True)
        sys.exit(1)
    except yt_dlp.utils.DownloadError as e:
        click.echo(f"Failed to extract video transcript: {e}", err=True)
        sys.exit(1)   

    return metadata, ttml_path
ytt_fetch(url, lang, keep, info, no_embed, output)

YouTube Transcript Fetch: Retrieve and save transcripts for a YouTube video using yt-dlp.

Source code in src/tnh_scholar/cli_tools/ytt_fetch/ytt_fetch.py
@click.command()
@click.argument("url")
@click.option(
    "-l", "--lang", default="en", help="Language code for transcript (default: en)"
)
@click.option(
    "-k", "--keep",
    is_flag=True,
    help="Keep downloaded datafile: TTML transcript."
)
@click.option(
    "-i", "--info",
    is_flag=True,
    help="Return only metadata in YAML frontmatter format." 
    )
@click.option(
    "-n", "--no-embed",
    is_flag=True,
    help="Do not embed metadata in transcript file."
)
@click.option(
    "-o",
    "--output",
    type=click.Path(),
    help="Save transcript text to file instead of printing.",
)
def ytt_fetch(
    url: str, 
    lang: str, 
    keep: bool, 
    info: bool,
    no_embed: bool,
    output: Optional[str]) -> None:
    """
    YouTube Transcript Fetch: Retrieve and 
    save transcripts for a Youtube video using yt-dlp.
    """

    dl = DLPDownloader()

    output_path = Path(output) if output else None

    if not info:  
        generate_transcript(dl, url, lang, keep, no_embed, output_path)
    else:
        generate_metadata(dl, url, keep, output_path)

exceptions

__all__ = ['TnhScholarError', 'ConfigurationError', 'ValidationError', 'ExternalServiceError', 'RateLimitError', 'NotRetryable'] module-attribute

ConfigurationError

Bases: TnhScholarError

Configuration-related errors (missing env vars, invalid settings, etc.).

Source code in src/tnh_scholar/exceptions.py
class ConfigurationError(TnhScholarError):
    """Configuration-related errors (missing env vars, invalid settings, etc.)."""

ExternalServiceError

Bases: TnhScholarError

Upstream/provider errors (HTTP 5xx, transport, transient provider issues).

Source code in src/tnh_scholar/exceptions.py
class ExternalServiceError(TnhScholarError):
    """Upstream/provider errors (HTTP 5xx, transport, transient provider issues)."""

NotRetryable

Bases: TnhScholarError

Marker for errors where retry is known to be pointless (e.g., bad auth).

Source code in src/tnh_scholar/exceptions.py
class NotRetryable(TnhScholarError):
    """Marker for errors where retry is known to be pointless (e.g., bad auth)."""

RateLimitError

Bases: ExternalServiceError

Upstream rate limits; typically retryable after a backoff.

Source code in src/tnh_scholar/exceptions.py
class RateLimitError(ExternalServiceError):
    """Upstream rate limits; typically retryable after a backoff."""

TnhScholarError

Bases: Exception

Base exception for all tnh_scholar errors.

Attributes:

    message: Human-readable summary.
    context: Optional structured context to aid logging/diagnostics. Keep this JSON-serializable.
    cause: Optional underlying exception.

Source code in src/tnh_scholar/exceptions.py
class TnhScholarError(Exception):
    """Base exception for all tnh_scholar errors.

    Attributes:
        message: Human-readable summary.
        context: Optional structured context to aid logging/diagnostics.
                 Keep this JSON-serializable.
        cause:   Optional underlying exception.
    """
    def __init__(
        self,
        message: str = "",
        *,
        context: Optional[Mapping[str, Any]] = None,
        cause: Optional[BaseException] = None,
    ) -> None:
        super().__init__(message)
        self.message = message
        self.context = dict(context) if context else {}
        self.__cause__ = cause  # preserves exception chaining

    def __str__(self) -> str:
        return self.message or self.__class__.__name__
__cause__ = cause instance-attribute
context = dict(context) if context else {} instance-attribute
message = message instance-attribute
__init__(message='', *, context=None, cause=None)
Source code in src/tnh_scholar/exceptions.py
def __init__(
    self,
    message: str = "",
    *,
    context: Optional[Mapping[str, Any]] = None,
    cause: Optional[BaseException] = None,
) -> None:
    super().__init__(message)
    self.message = message
    self.context = dict(context) if context else {}
    self.__cause__ = cause  # preserves exception chaining
__str__()
Source code in src/tnh_scholar/exceptions.py
def __str__(self) -> str:
    return self.message or self.__class__.__name__
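A usage sketch of the context/cause mechanics: wrap a low-level error, attach structured context, and rely on the `__str__` fallback. The class is copied from above so the snippet runs standalone; the configuration key is invented for the demo.

```python
class TnhScholarError(Exception):
    """Copied from exceptions.py above so this snippet is self-contained."""
    def __init__(self, message="", *, context=None, cause=None):
        super().__init__(message)
        self.message = message
        self.context = dict(context) if context else {}
        self.__cause__ = cause  # preserves exception chaining

    def __str__(self):
        return self.message or self.__class__.__name__

try:
    try:
        raise KeyError("OPENAI_API_KEY")
    except KeyError as e:
        # Wrap the low-level error with JSON-serializable context
        # (the "var" key is hypothetical, chosen for this demo).
        raise TnhScholarError(
            "Missing credential",
            context={"var": "OPENAI_API_KEY"},
            cause=e,
        )
except TnhScholarError as err:
    summary = str(err)            # human-readable message
    ctx = err.context             # structured context for logging
    underlying = err.__cause__    # chained original exception
```

Note that `str(TnhScholarError())` falls back to the class name when the message is empty, which keeps log lines informative even for bare raises.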

ValidationError

Bases: TnhScholarError

Input/data validation errors (precondition failures before calling providers).

Source code in src/tnh_scholar/exceptions.py
class ValidationError(TnhScholarError):
    """Input/data validation errors (precondition failures before calling providers)."""

journal_processing

journal_process

BATCH_RETRY_DELAY = 5 module-attribute
DEFAULT_JOURNAL_MODEL = 'gpt-4o' module-attribute
DEFAULT_MODEL_SETTINGS = {'gpt-4o': {'max_tokens': 16000, 'temperature': 1.0}, 'gpt-3.5-turbo': {'max_tokens': 4096, 'temperature': 1.0}, 'gpt-4o-mini': {'max_tokens': 16000, 'temperature': 1.0}} module-attribute
MAX_BATCH_RETRIES = 40 module-attribute
MAX_TOKEN_LIMIT = 60000 module-attribute
journal_schema = {'type': 'object', 'properties': {'journal_summary': {'type': 'string'}, 'sections': {'type': 'array', 'items': {'type': 'object', 'properties': {'title_vi': {'type': 'string'}, 'title_en': {'type': 'string'}, 'author': {'type': ['string', 'null']}, 'summary': {'type': 'string'}, 'keywords': {'type': 'array', 'items': {'type': 'string'}}, 'start_page': {'type': 'integer', 'minimum': 1}, 'end_page': {'type': 'integer', 'minimum': 1}}, 'required': ['title_vi', 'title_en', 'summary', 'keywords', 'start_page', 'end_page']}}}, 'required': ['journal_summary', 'sections']} module-attribute
logger = logging.getLogger('journal_process') module-attribute
ModelSettings

Bases: TypedDict

Source code in src/tnh_scholar/journal_processing/journal_process.py
class ModelSettings(TypedDict):
    max_tokens: int
    temperature: float
max_tokens instance-attribute
temperature instance-attribute
batch_section(input_xml_path, batch_jsonl, system_message, journal_name)

Splits the journal content into sections using GPT, with retries for both starting and completing the batch.

Parameters:

    input_xml_path (Path): Path to the input XML file.
    batch_jsonl (Path): Path for the generated JSONL batch request file.
    system_message: System message used to build the GPT sectioning request.
    journal_name (str): Name of the journal being processed.

Returns:

    str: The result of the batch sectioning process as a serialized JSON object.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def batch_section(
    input_xml_path: Path, batch_jsonl: Path, system_message, journal_name
):
    """
    Splits the journal content into sections using GPT, with retries for both starting and completing the batch.

    Args:
        input_xml_path (Path): Path to the input XML file.
        batch_jsonl (Path): Path for the generated JSONL batch request file.
        system_message: System message used to build the GPT sectioning request.
        journal_name (str): Name of the journal being processed.

    Returns:
        str: The result of the batch sectioning process as a serialized JSON object.
    """
    try:
        logger.info(
            f"Starting sectioning batch for {journal_name} with file:\n\t{input_xml_path}"
        )
        # Load journal content
        journal_pages = read_str_from_file(input_xml_path)

        # Create GPT messages for sectioning
        user_message_wrapper = lambda text: f"{text}"
        messages = generate_messages(
            system_message, user_message_wrapper, [journal_pages]
        )

        # Create JSONL file for batch processing
        jsonl_file = create_jsonl_file_for_batch(messages, batch_jsonl, json_mode=True)

    except Exception as e:
        logger.error(
            f"Failed to initialize batch sectioning data for journal '{journal_name}'.",
            extra={"input_xml_path": input_xml_path},
            exc_info=True,
        )
        raise RuntimeError(
            f"Error initializing batch sectioning data for journal '{journal_name}'."
        ) from e

    response = start_batch_with_retries(
        jsonl_file,
        description=f"Batch for sectioning journal: {journal_name} | input file: {input_xml_path}",
    )

    if response:
        json_result = response[
            0
        ]  # should return json, just one batch so first response
        # Log success and return output json
        logger.info(
            f"Successfully batch sectioned journal '{journal_name}' with input file: {input_xml_path}."
        )
        return json_result
    else:
        logger.error("Section batch failed to get response.")
        return ""
batch_translate(input_xml_path, batch_json_path, metadata_path, system_message, journal_name)

Translates the journal sections using the GPT model. Saves the translated content back to XML.

Parameters:

    input_xml_path (Path): Path to the input XML file.
    batch_json_path (Path): Path for the generated JSONL batch request file.
    metadata_path (Path): Path to the metadata JSON file.
    system_message: System message used to build the GPT translation requests.
    journal_name (str): Name of the journal.

Returns:

    The translation data produced by the batch translation process.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def batch_translate(
    input_xml_path: Path,
    batch_json_path: Path,
    metadata_path: Path,
    system_message,
    journal_name: str,
):
    """
    Translates the journal sections using the GPT model.
    Saves the translated content back to XML.

    Args:
        input_xml_path (Path): Path to the input XML file.
        batch_json_path (Path): Path for the generated JSONL batch request file.
        metadata_path (Path): Path to the metadata JSON file.
        system_message: System message used to build the GPT translation requests.
        journal_name (str): Name of the journal.

    Returns:
        The translation data produced by the batch translation process.
    """
    logger.info(
        f"Starting translation batch for journal '{journal_name}':\n\twith file: {input_xml_path}\n\tmetadata: {metadata_path}"
    )

    # Data initialization:
    try:
        # load metadata
        serial_json = read_str_from_file(metadata_path)

        section_metadata = deserialize_json(serial_json)
        if not section_metadata:
            raise RuntimeError(f"Metadata could not be loaded from {metadata_path}.")

        # Extract page groups and split XML content
        page_groups = extract_page_groups_from_metadata(section_metadata)
        xml_content = read_str_from_file(input_xml_path)
        section_contents = split_xml_on_pagebreaks(xml_content, page_groups)

        if section_contents:
            logger.debug(f"section_contents[0]:\n{section_contents[0]}")
        else:
            logger.error("No sectin contents.")

    except Exception as e:
        logger.error(
            f"Failed to initialize data for translation batching for journal '{journal_name}'.",
            exc_info=True,
        )
        raise RuntimeError(
            f"Error during data initialization for journal '{journal_name}'."
        ) from e

    translation_data = translate_sections(
        batch_json_path,
        system_message,
        section_contents,
        section_metadata,
        journal_name,
    )
    return translation_data
create_jsonl_file_for_batch(messages, output_file_path=None, max_token_list=None, model=DEFAULT_JOURNAL_MODEL, tools=None, json_mode=False)

Write a JSONL batch file mirroring the legacy OpenAI format.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def create_jsonl_file_for_batch(
    messages: list[list[dict[str, str]]],
    output_file_path: Path | str | None = None,
    max_token_list: list[int] | None = None,
    model: str = DEFAULT_JOURNAL_MODEL,
    tools=None,
    json_mode: bool | None = False,
):
    """Write a JSONL batch file mirroring the legacy OpenAI format."""
    model_settings = _get_model_settings(model)
    if not max_token_list:
        max_tokens = model_settings["max_tokens"]
        max_token_list = [max_tokens] * len(messages)

    temperature = model_settings["temperature"]
    total_tokens = sum(max_token_list)

    if output_file_path is None:
        date_str = datetime.now().strftime("%m%d%Y")
        resolved_output = Path(f"batch_requests_{date_str}.jsonl")
    else:
        resolved_output = Path(output_file_path)

    output_dir = resolved_output.parent
    output_dir.mkdir(parents=True, exist_ok=True)

    requests: list[dict[str, Any]] = []
    for i, message in enumerate(messages):
        max_tokens = max_token_list[i]
        request_obj: dict[str, Any] = {
            "custom_id": f"request-{i+1}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": model,
                "messages": message,
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
        }
        if json_mode:
            request_obj["body"]["response_format"] = {"type": "json_object"}
        if tools:
            request_obj["body"]["tools"] = tools

        requests.append(request_obj)

    with resolved_output.open("w", encoding="utf-8") as handle:
        for request in requests:
            json.dump(request, handle)
            handle.write("\n")

    logger.info(
        "JSONL batch file created at %s with ~%s requested tokens.",
        resolved_output,
        total_tokens,
    )
    return resolved_output
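For reference, each line of the generated JSONL file is a standalone chat-completions request. The sketch below builds one record in the same shape the function writes; the model name is a placeholder, not the package's `DEFAULT_JOURNAL_MODEL`.

```python
import json

# One record in the same shape create_jsonl_file_for_batch writes per line.
message = [
    {"role": "system", "content": "You are a careful translator."},
    {"role": "user", "content": "Translate: Bonjour"},
]
request_obj = {
    "custom_id": "request-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "example-model",  # placeholder model name
        "messages": message,
        "max_tokens": 256,
        "temperature": 0.7,
    },
}
jsonl_line = json.dumps(request_obj)
```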
deserialize_json(serialized_data)

Converts a serialized JSON string into a Python dictionary.

Parameters:

Name Type Description Default
serialized_data str

The JSON string to deserialize.

required

Returns:

Name Type Description
dict

The deserialized Python dictionary.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def deserialize_json(serialized_data: str):
    """
    Converts a serialized JSON string into a Python dictionary.

    Args:
        serialized_data (str): The JSON string to deserialize.

    Returns:
        dict: The deserialized Python dictionary.
    """
    if not isinstance(serialized_data, str):
        logger.error(
            f"String input required for deserialize_json. Received: {type(serialized_data)}"
        )
        raise ValueError("String input required.")

    try:
        # Convert the JSON string into a dictionary
        return json.loads(serialized_data)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to deserialize JSON: {e}")
        raise
extract_page_groups_from_metadata(metadata)

Extracts page groups from the section metadata for use with split_xml_pages.

Parameters:

Name Type Description Default
metadata dict

The section metadata containing sections with start and end pages.

required

Returns:

Type Description

List[Tuple[int, int]]: A list of tuples, each representing a page range (start_page, end_page).

Source code in src/tnh_scholar/journal_processing/journal_process.py
def extract_page_groups_from_metadata(metadata):
    """
    Extracts page groups from the section metadata for use with `split_xml_pages`.

    Parameters:
        metadata (dict): The section metadata containing sections with start and end pages.

    Returns:
        List[Tuple[int, int]]: A list of tuples, each representing a page range (start_page, end_page).
    """
    page_groups = []

    # Ensure metadata contains sections
    if "sections" not in metadata or not isinstance(metadata["sections"], list):
        raise ValueError(
            "Metadata does not contain a valid 'sections' key with a list of sections."
        )

    for section in metadata["sections"]:
        try:
            # Extract start and end pages
            start_page = section.get("start_page")
            end_page = section.get("end_page")

            # Ensure both start_page and end_page are integers
            if not isinstance(start_page, int) or not isinstance(end_page, int):
                raise ValueError(f"Invalid page range in section: {section}")

            # Add the tuple to the page groups list
            page_groups.append((start_page, end_page))

        except KeyError as e:
            print(f"Missing key in section metadata: {e}")
        except ValueError as e:
            print(f"Error processing section metadata: {e}")

    logger.debug(f"page groups found: {page_groups}")

    return page_groups
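A minimal sketch of the extraction, using hypothetical section metadata of the shape the function expects:

```python
# Hypothetical metadata with the "sections" layout described above.
metadata = {
    "sections": [
        {"title_en": "Opening", "start_page": 1, "end_page": 4},
        {"title_en": "Dharma Talk", "start_page": 5, "end_page": 12},
    ]
}

# Same core transformation: one (start_page, end_page) tuple per section.
page_groups = [
    (section["start_page"], section["end_page"])
    for section in metadata["sections"]
]
```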
generate_all_batches(processed_document_dir, system_message, user_wrap_function, file_regex='.*\\.xml')

Generate cleaning batches for all journals in the specified directory.

Parameters:

Name Type Description Default
processed_document_dir str

Path to the directory containing processed journal data.

required
system_message str

System message template for batch processing.

required
user_wrap_function callable

Function to wrap user input for processing pages.

required
file_regex str

Regex pattern to identify target files (default: ".*.xml").

'.*\\.xml'

Returns:

Type Description

None

Source code in src/tnh_scholar/journal_processing/journal_process.py
def generate_all_batches(
    processed_document_dir: str,
    system_message: str,
    user_wrap_function,
    file_regex: str = r".*\.xml",
):
    """
    Generate cleaning batches for all journals in the specified directory.

    Parameters:
        processed_document_dir (str): Path to the directory containing processed journal data.
        system_message (str): System message template for batch processing.
        user_wrap_function (callable): Function to wrap user input for processing pages.
        file_regex (str): Regex pattern to identify target files (default: ".*\\.xml").

    Returns:
        None
    """
    logger = logging.getLogger(__name__)
    document_dir = Path(processed_document_dir)
    regex = re.compile(file_regex)

    for journal_file in document_dir.iterdir():
        if journal_file.is_file() and regex.search(journal_file.name):
            try:
                # Derive output file path
                output_file = journal_file.with_suffix(".jsonl")
                logger.info(f"Generating batch for {journal_file}...")

                # Call single batch function
                generate_single_oa_batch_from_pages(
                    input_xml_file=str(journal_file),
                    output_file=str(output_file),
                    system_message=system_message,
                    user_wrap_function=user_wrap_function,
                )
            except Exception as e:
                logger.error(f"Failed to process {journal_file}: {e}")
                continue

    logger.info("Batch generation completed.")
generate_clean_batch(input_xml_file, output_file, system_message, user_wrap_function)

Generate a batch file for the OpenAI (OA) API using a single input XML file.

Parameters:

Name Type Description Default
input_xml_file str

Full path to the input XML file to process.

required
output_file str

Full path to the output batch JSONL file.

required
system_message str

System message template for batch processing.

required
user_wrap_function callable

Function to wrap user input for processing pages.

required

Returns:

Name Type Description
str

Path to the created batch file.

Raises:

Type Description
Exception

If an error occurs during file processing.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def generate_clean_batch(
    input_xml_file: str, output_file: str, system_message: str, user_wrap_function
):
    """
    Generate a batch file for the OpenAI (OA) API using a single input XML file.

    Parameters:
        input_xml_file (str): Full path to the input XML file to process.
        output_file (str): Full path to the output batch JSONL file.
        system_message (str): System message template for batch processing.
        user_wrap_function (callable): Function to wrap user input for processing pages.

    Returns:
        str: Path to the created batch file.

    Raises:
        Exception: If an error occurs during file processing.
    """

    try:
        # Read the OCR text from the batch file
        text = read_str_from_file(input_xml_file)
        logger.info(f"Processing file: {input_xml_file}")

        # Split the text into pages for processing
        pages = split_xml_on_pagebreaks(text)
        pages = wrap_all_lines(pages)  # wrap lines with brackets.
        if not pages:
            raise ValueError(f"No pages found in XML file: {input_xml_file}")
        logger.info(f"Found {len(pages)} pages in {input_xml_file}.")

        max_tokens = [_get_max_tokens_for_clean(page) for page in pages]

        # Generate messages for the pages
        batch_message_seq = generate_messages(system_message, user_wrap_function, pages)

        # Save the batch file
        create_jsonl_file_for_batch(
            batch_message_seq, output_file, max_token_list=max_tokens
        )
        logger.info(f"Batch file created successfully: {output_file}")

        return output_file

    except FileNotFoundError:
        logger.error("File not found.")
        raise
    except ValueError as e:
        logger.error(f"Value error: {e}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error while processing {input_xml_file}: {e}")
        raise
generate_messages(system_message, user_message_wrapper, data_list_to_process, log_system_message=True)

Build OpenAI-style chat message payloads.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def generate_messages(
    system_message: str,
    user_message_wrapper: Callable[[object], str],
    data_list_to_process: Sequence[object],
    log_system_message: bool = True,
) -> list[list[dict[str, str]]]:
    """Build OpenAI-style chat message payloads."""
    if log_system_message:
        logger.debug("System message:\n%s", system_message)

    messages = []
    for data_element in data_list_to_process:
        message_block = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message_wrapper(data_element)},
        ]
        messages.append(message_block)
    return messages
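The payload shape can be sketched without the package: one `[system, user]` message pair per input element, with the user content produced by the wrap function. The `wrap_page` helper here is hypothetical, standing in for a real `user_message_wrapper`.

```python
# Mirror of the generate_messages loop: one [system, user] pair per element.
system_message = "Clean the OCR text, preserving line markers."

def wrap_page(page: str) -> str:
    # Hypothetical user-wrap function for this sketch.
    return f"<page>{page}</page>"

pages = ["first page text", "second page text"]
messages = [
    [
        {"role": "system", "content": system_message},
        {"role": "user", "content": wrap_page(page)},
    ]
    for page in pages
]
```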
generate_single_oa_batch_from_pages(input_xml_file, output_file, system_message, user_wrap_function)

*** Deprecated *** Generate a batch file for the OpenAI (OA) API using a single input XML file.

Parameters:

Name Type Description Default
input_xml_file str

Full path to the input XML file to process.

required
output_file str

Full path to the output batch JSONL file.

required
system_message str

System message template for batch processing.

required
user_wrap_function callable

Function to wrap user input for processing pages.

required

Returns:

Name Type Description
str

Path to the created batch file.

Raises:

Type Description
Exception

If an error occurs during file processing.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def generate_single_oa_batch_from_pages(
    input_xml_file: str,
    output_file: str,
    system_message: str,
    user_wrap_function,
):
    """
    *** Deprecated ***
    Generate a batch file for the OpenAI (OA) API using a single input XML file.

    Parameters:
        input_xml_file (str): Full path to the input XML file to process.
        output_file (str): Full path to the output batch JSONL file.
        system_message (str): System message template for batch processing.
        user_wrap_function (callable): Function to wrap user input for processing pages.

    Returns:
        str: Path to the created batch file.

    Raises:
        Exception: If an error occurs during file processing.
    """
    logger = logging.getLogger(__name__)

    try:
        # Read the OCR text from the batch file
        text = read_str_from_file(input_xml_file)
        logger.info(f"Processing file: {input_xml_file}")

        # Split the text into pages for processing
        pages = split_xml_pages(text)
        if not pages:
            raise ValueError(f"No pages found in XML file: {input_xml_file}")
        logger.info(f"Found {len(pages)} pages in {input_xml_file}.")

        # Generate messages for the pages
        batch_message_seq = generate_messages(system_message, user_wrap_function, pages)

        # Save the batch file
        create_jsonl_file_for_batch(batch_message_seq, output_file)
        logger.info(f"Batch file created successfully: {output_file}")

        return output_file

    except FileNotFoundError:
        logger.error(f"File not found: {input_xml_file}")
        raise
    except ValueError as e:
        logger.error(f"Value error: {e}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error while processing {input_xml_file}: {e}")
        raise
run_immediate_chat_process(messages, max_tokens=0, response_format=None, model=DEFAULT_JOURNAL_MODEL)

Legacy-compatible immediate completion powered by GenAI simple_completion.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def run_immediate_chat_process(
    messages: list[dict[str, str]],
    max_tokens: int = 0,
    response_format=None,
    model: str = DEFAULT_JOURNAL_MODEL,
):
    """Legacy-compatible immediate completion powered by GenAI simple_completion."""
    system_message, user_message = _extract_message_parts(messages)
    if not max_tokens:
        max_tokens = _get_model_settings(model)["max_tokens"]

    return simple_completion(
        system_message=system_message,
        user_message=user_message,
        model=model,
        max_tokens=max_tokens,
    )
save_cleaned_data(cleaned_xml_path, cleaned_wrapped_pages, journal_name)
Source code in src/tnh_scholar/journal_processing/journal_process.py
def save_cleaned_data(
    cleaned_xml_path: Path, cleaned_wrapped_pages: List[str], journal_name
):
    try:
        logger.info(f"Saving cleaned content to XML for journal '{journal_name}'.")
        cleaned_wrapped_pages = unwrap_all_lines(cleaned_wrapped_pages)
        save_pages_to_xml(cleaned_xml_path, cleaned_wrapped_pages, overwrite=True)
        logger.info(f"Cleaned journal saved successfully to:\n\t{cleaned_xml_path}")
    except Exception as e:
        logger.error(
            f"Failed to save cleaned data for journal '{journal_name}'.",
            extra={"cleaned_xml_path": cleaned_xml_path},
            exc_info=True,
        )
        raise RuntimeError(
            f"Failed to save cleaned data for journal '{journal_name}'."
        ) from e
save_sectioning_data(output_json_path, raw_output_path, serial_json, journal_name)
Source code in src/tnh_scholar/journal_processing/journal_process.py
def save_sectioning_data(
    output_json_path: Path, raw_output_path: Path, serial_json: str, journal_name
):
    try:
        raw_output_path.write_text(serial_json, encoding="utf-8")
    except Exception as e:
        logger.error(
            f"Failed to write raw response file for journal '{journal_name}'.",
            extra={"raw_output_path": raw_output_path},
            exc_info=True,
        )
        raise RuntimeError(
            f"Failed to write raw response file for journal '{journal_name}'."
        ) from e

    # Validate and save metadata
    try:
        valid = validate_and_save_metadata(
            output_json_path, serial_json, journal_schema
        )
        if not valid:
            raise RuntimeError(
                f"Validation failed for metadata of journal '{journal_name}'."
            )
    except Exception as e:
        logger.error(
            f"Error occurred while validating and saving metadata for journal '{journal_name}'.",
            extra={"output_json_path": output_json_path},
            exc_info=True,
        )
        raise RuntimeError(f"Validation error for journal '{journal_name}'.") from e

    return output_json_path
save_translation_data(xml_output_path, translation_data, journal_name)
Source code in src/tnh_scholar/journal_processing/journal_process.py
def save_translation_data(xml_output_path: Path, translation_data, journal_name):
    # Save translated content back to XML
    try:
        logger.info(f"Saving translated content to XML for journal '{journal_name}'.")
        join_xml_data_to_doc(xml_output_path, translation_data, overwrite=True)
        logger.info(f"Translated journal saved successfully to:\n\t{xml_output_path}")

    except Exception as e:
        logger.error(
            f"Failed to save translation data for journal '{journal_name}'.",
            extra={"xml_output_path": xml_output_path},
            exc_info=True,
        )
        raise RuntimeError(
            f"Failed to save translation data for journal '{journal_name}'."
        ) from e
send_data_for_tx_batch(batch_jsonl_path, section_data_to_send, system_message, max_token_list, journal_name, immediate=False)

Sends data for translation batch or immediate processing.

Parameters:

Name Type Description Default
batch_jsonl_path Path

Path for the JSONL file to save batch data.

required
section_data_to_send List

List of section data to translate.

required
system_message str

System message for the translation process.

required
max_token_list List

List of max tokens for each section.

required
journal_name str

Name of the journal being processed.

required
immediate bool

If True, run immediate chat processing instead of batch.

False

Returns:

Name Type Description
List

Translated data from the batch or immediate process.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def send_data_for_tx_batch(
    batch_jsonl_path: Path,
    section_data_to_send: List,
    system_message,
    max_token_list: List,
    journal_name,
    immediate=False,
):
    """
    Sends data for translation batch or immediate processing.

    Args:
        batch_jsonl_path (Path): Path for the JSONL file to save batch data.
        section_data_to_send (List): List of section data to translate.
        system_message (str): System message for the translation process.
        max_token_list (List): List of max tokens for each section.
        journal_name (str): Name of the journal being processed.
        immediate (bool): If True, run immediate chat processing instead of batch.

    Returns:
        List: Translated data from the batch or immediate process.
    """
    try:
        # Generate all messages using the generate_messages function
        user_message_wrapper = (
            lambda section_info: f"Translate this section with title '{section_info.title}':\n{section_info.content}"
        )
        messages = generate_messages(
            system_message, user_message_wrapper, section_data_to_send
        )

        if immediate:
            logger.info(f"Running immediate chat process for journal '{journal_name}'.")
            translated_data = []
            for i, message in enumerate(messages):
                max_tokens = max_token_list[i]
                response = run_immediate_chat_process(message, max_tokens=max_tokens)
                translated_data.append(response)
            logger.info(
                f"Immediate translation completed for journal '{journal_name}'."
            )
            return translated_data
        else:
            logger.info(f"Running batch processing for journal '{journal_name}'.")
            # Create batch file for batch processing
            jsonl_file = create_jsonl_file_for_batch(
                messages, batch_jsonl_path, max_token_list=max_token_list
            )
            if not jsonl_file:
                raise RuntimeError("Failed to create JSONL file for translation batch.")

            # Process batch and return the result
            translation_data = start_batch_with_retries(
                jsonl_file,
                description=f"Batch for translating journal '{journal_name}'",
            )
            logger.info(f"Batch translation completed for journal '{journal_name}'.")
            return translation_data

    except Exception as e:
        logger.error(
            f"Error during translation processing for journal '{journal_name}'.",
            exc_info=True,
        )
        raise RuntimeError("Error in translation process.") from e
setup_logger(log_file_path)

Configures the logger to write to a log file and the console. Adds a custom "PRIORITY_INFO" logging level for important messages.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def setup_logger(log_file_path):
    """
    Configures the logger to write to a log file and the console.
    Adds a custom "PRIORITY_INFO" logging level for important messages.
    """
    # Remove existing handlers
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",  # Include logger name
        handlers=[
            logging.FileHandler(log_file_path, encoding="utf-8"),
            logging.StreamHandler(),  # Optional: to log to the console as well
        ],
    )

    # Suppress DEBUG/INFO logs for specific noisy modules
    modules_to_suppress = ["httpx", "httpcore", "urllib3", "openai", "google"]
    for module in modules_to_suppress:
        logger = logging.getLogger(module)
        logger.setLevel(logging.WARNING)  # Suppress DEBUG and INFO logs

    # Add a custom "PRIORITY_INFO" level
    PRIORITY_INFO_LEVEL = 25  # Between INFO (20) and WARNING (30)
    logging.addLevelName(PRIORITY_INFO_LEVEL, "PRIORITY_INFO")

    def priority_info(self, message, *args, **kwargs):
        if self.isEnabledFor(PRIORITY_INFO_LEVEL):
            self._log(PRIORITY_INFO_LEVEL, f"\033[93m{message}\033[0m", args, **kwargs)

    logging.Logger.priority_info = priority_info

    return logging.getLogger(__name__)
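The custom level registration can be reproduced in isolation; level 25 sits between `INFO` (20) and `WARNING` (30), so `priority_info` messages survive an `INFO`-level filter while remaining below warnings. A minimal sketch:

```python
import logging

# Register the same custom level setup_logger adds.
PRIORITY_INFO_LEVEL = 25  # between INFO (20) and WARNING (30)
logging.addLevelName(PRIORITY_INFO_LEVEL, "PRIORITY_INFO")

logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)

# After setup_logger() has run, callers can emit at the custom level:
logger.log(PRIORITY_INFO_LEVEL, "Important batch milestone")
```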
start_batch_with_retries(jsonl_file, description='', max_retries=MAX_BATCH_RETRIES, retry_delay=BATCH_RETRY_DELAY, poll_interval=10, timeout=3600)

Simulate the legacy batch runner using sequential simple_completion calls.

The parameters mirror the old interface so callers remain unchanged, but the implementation now iterates through the JSONL requests locally.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def start_batch_with_retries(
    jsonl_file: Path,
    description: str = "",
    max_retries: int = MAX_BATCH_RETRIES,
    retry_delay: int = BATCH_RETRY_DELAY,
    poll_interval: int = 10,
    timeout: int = 3600,
) -> list[str]:
    """
    Simulate the legacy batch runner using sequential simple_completion calls.

    The parameters mirror the old interface so callers remain unchanged, but the
    implementation now iterates through the JSONL requests locally.
    """
    logger.info(
        "Running sequential batch for '%s' using %s",
        description,
        jsonl_file,
    )
    responses: list[str] = []
    try:
        with jsonl_file.open("r", encoding="utf-8") as handle:
            for line_no, line in enumerate(handle, start=1):
                if not line.strip():
                    continue
                payload = json.loads(line)
                body = payload.get("body", {})
                request_model = body.get("model", DEFAULT_JOURNAL_MODEL)
                messages = body.get("messages", [])
                max_tokens = body.get("max_tokens") or body.get("max_completion_tokens")
                if not max_tokens:
                    max_tokens = _get_model_settings(request_model)["max_tokens"]
                system_message, user_message = _extract_message_parts(messages)
                response = simple_completion(
                    system_message=system_message,
                    user_message=user_message,
                    model=request_model,
                    max_tokens=max_tokens,
                )
                responses.append(response)
                logger.debug("Processed request %s from batch file", line_no)

    except Exception as exc:
        logger.error(
            "Failed to process batch '%s' from %s",
            description or jsonl_file,
            jsonl_file,
            exc_info=True,
        )
        raise RuntimeError("Failed to process batch sequentially") from exc

    logger.info(
        "Sequential batch for '%s' completed with %s responses.",
        description or jsonl_file,
        len(responses),
    )
    return responses
translate_sections(batch_jsonl_path, system_message, section_contents, section_metadata, journal_name, immediate=False)

Build up sections into token-bounded batches and translate each batch.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def translate_sections(
    batch_jsonl_path: Path,
    system_message,
    section_contents,
    section_metadata,
    journal_name,
    immediate=False,
):
    """build up sections in batches to translate"""

    section_mdata = section_metadata["sections"]
    if len(section_contents) != len(section_mdata):
        raise RuntimeError("Section length mismatch.")

    # collate metadata and section content, calculate max_tokens per section:
    section_data_to_send = []
    max_token_list = []
    current_token_count = 0
    collected_translations = []
    section_last_index = len(section_mdata) - 1

    for i, section_info in enumerate(section_mdata):
        section_content = section_contents[i]
        max_tokens = floor(token_count(section_content) * 1.3) + 1000
        max_token_list.append(max_tokens)
        current_token_count += max_tokens
        section_data = SimpleNamespace(
            title=section_info["title_en"], content=section_content
        )
        section_data_to_send.append(section_data)
        logger.debug(f"section {i}: {section_data.title} added for batch processing.")

        if current_token_count >= MAX_TOKEN_LIMIT or i == section_last_index:
            # send sections for batch processing since token limit reached.
            batch_result = send_data_for_tx_batch(
                batch_jsonl_path,
                section_data_to_send,
                system_message,
                max_token_list,
                journal_name,
                immediate,
            )
            collected_translations.extend(batch_result)

            # reset containers to start building up next batch.
            section_data_to_send = []
            max_token_list = []
            current_token_count = 0

    return collected_translations
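The accumulation pattern is worth isolating: each section's budget is its token count scaled by 1.3 plus a 1000-token margin, and a batch is flushed when the running total reaches the limit or the last section is seen. In this sketch the `MAX_TOKEN_LIMIT` value and the word-count `token_count` stand-in are assumptions, not the package's real values.

```python
from math import floor

MAX_TOKEN_LIMIT = 2000  # assumed limit for this sketch

def token_count(text: str) -> int:
    # Crude stand-in for the real tokenizer: one token per word.
    return len(text.split())

sections = ["short section", "another section " * 100, "final one"]
batches, current, running = [], [], 0
for i, content in enumerate(sections):
    # Per-section budget: scaled token count plus a fixed margin.
    budget = floor(token_count(content) * 1.3) + 1000
    current.append(content)
    running += budget
    # Flush when the limit is reached or this is the last section.
    if running >= MAX_TOKEN_LIMIT or i == len(sections) - 1:
        batches.append(current)
        current, running = [], 0
```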
unwrap_all_lines(pages)
Source code in src/tnh_scholar/journal_processing/journal_process.py
def unwrap_all_lines(pages):
    result = []
    for page in pages:
        if page == "blank page":
            result.append(page)
        else:
            result.append(unwrap_lines(page))
    return result
unwrap_lines(text)
Removes angle brackets (< >) from encapsulated lines and merges them into
a newline-separated string.

Parameters:
    text (str): The input string with encapsulated lines.

Returns:
    str: A newline-separated string with the encapsulation removed.

Example:
    >>> unwrap_lines("<Line 1> <Line 2> <Line 3>")
    'Line 1\nLine 2\nLine 3'
    >>> unwrap_lines("<Line 1>\n<Line 2>\n<Line 3>")
    'Line 1\nLine 2\nLine 3'

Source code in src/tnh_scholar/journal_processing/journal_process.py
def unwrap_lines(text: str) -> str:
    """
    Removes angle brackets (< >) from encapsulated lines and merges them into
    a newline-separated string.

    Parameters:
        text (str): The input string with encapsulated lines.

    Returns:
        str: A newline-separated string with the encapsulation removed.

    Example:
        >>> unwrap_lines("<Line 1> <Line 2> <Line 3>")
        'Line 1\nLine 2\nLine 3'
        >>> unwrap_lines("<Line 1>\n<Line 2>\n<Line 3>")
        'Line 1\nLine 2\nLine 3'
    """
    # Find all content between < and > using regex
    matches = re.findall(r"<(.*?)>", text)
    # Join the extracted content with newlines
    return "\n".join(matches)
validate_and_clean_data(data, schema)

Recursively validate and clean AI-generated data to fit the given schema. Any missing fields are filled with defaults, and extra fields are ignored.

Parameters:

Name Type Description Default
data dict

The AI-generated data to validate and clean.

required
schema dict

The schema defining the required structure.

required

Returns:

Name Type Description
dict

The cleaned data adhering to the schema.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def validate_and_clean_data(data, schema):
    """
    Recursively validate and clean AI-generated data to fit the given schema.
    Any missing fields are filled with defaults, and extra fields are ignored.

    Args:
        data (dict): The AI-generated data to validate and clean.
        schema (dict): The schema defining the required structure.

    Returns:
        dict: The cleaned data adhering to the schema.
    """

    def clean_value(value, field_schema):
        """
        Clean a single value based on its schema, attempting type conversions where necessary.
        """
        field_type = field_schema["type"]

        # Handle type: string
        if field_type == "string":
            if isinstance(value, str):
                return value
            elif value is not None:
                return str(value)
            return "unset"

        # Handle type: integer
        elif field_type == "integer":
            if isinstance(value, int):
                return value
            elif isinstance(value, str) and value.isdigit():
                return int(value)
            try:
                return int(float(value))  # Handle cases like "2.0"
            except (ValueError, TypeError):
                return 0

        # Handle type: array
        elif field_type == "array":
            if isinstance(value, list):
                item_schema = field_schema.get("items", {})
                return [clean_value(item, item_schema) for item in value]
            elif isinstance(value, str):
                # Try splitting comma-separated strings into a list
                return [v.strip() for v in value.split(",")]
            return []

        # Handle type: object
        elif field_type == "object":
            if isinstance(value, dict):
                return validate_and_clean_data(value, field_schema)
            return {}

        # Handle nullable strings
        elif field_type == ["string", "null"]:
            if value is None or isinstance(value, str):
                return value
            return str(value)

        # Default case for unknown or unsupported types
        return "unset"

    def clean_object(obj, obj_schema):
        """
        Clean a dictionary object based on its schema.
        """
        if not isinstance(obj, dict):
            print(
                f"Expected dict but got: \n{type(obj)}: {obj}\nResetting to empty dict."
            )
            return {}
        cleaned = {}
        properties = obj_schema.get("properties", {})
        for key, field_schema in properties.items():
            # Set default value for missing fields
            cleaned[key] = clean_value(obj.get(key), field_schema)
        return cleaned

    # Handle the top-level object
    if schema["type"] == "object":
        cleaned_data = clean_object(data, schema)
        return cleaned_data
    else:
        raise ValueError("Top-level schema must be of type 'object'.")
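The coercion rules above can be condensed into a small sketch. This is an illustrative re-derivation of the string/integer branches, not the shipped function; `coerce` is a hypothetical name for this example only.

```python
# Illustrative sketch of validate_and_clean_data's coercion rules:
# strings pass through or are stringified, integers tolerate numeric
# strings like "2.0", and unconvertible values fall back to a default.
def coerce(value, field_type):
    if field_type == "string":
        if isinstance(value, str):
            return value
        return "unset" if value is None else str(value)
    if field_type == "integer":
        if isinstance(value, int):
            return value
        try:
            return int(float(value))  # handles "2" and "2.0"
        except (ValueError, TypeError):
            return 0
    return "unset"
```

The full function applies the same idea recursively through `array` and `object` schemas.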
validate_and_save_metadata(output_file_path, json_metadata_serial, schema)

Validates and cleans journal data against the schema, then writes it to a JSON file.

Parameters:

    output_file_path (Path): Path to the output JSON file. [required]
    json_metadata_serial (str): The journal metadata as a serialized JSON string to validate and clean. [required]
    schema (dict): The schema defining the required structure. [required]

Returns:

    bool: True if the file was written successfully. On failure the error is logged and the exception is re-raised.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def validate_and_save_metadata(
    output_file_path: Path, json_metadata_serial: str, schema
):
    """
    Validates and cleans journal data against the schema, then writes it to a JSON file.

    Args:
        output_file_path (Path): Path to the output JSON file.
        json_metadata_serial (str): The journal metadata as a serialized JSON string to validate and clean.
        schema (dict): The schema defining the required structure.

    Returns:
        bool: True if the file was written successfully; on failure the error is logged and the exception is re-raised.
    """
    try:
        # Clean the data to fit the schema
        data = deserialize_json(json_metadata_serial)
        cleaned_data = validate_and_clean_data(data, schema)

        # Write the parsed data to the specified JSON file
        with open(output_file_path, "w", encoding="utf-8") as f:
            json.dump(cleaned_data, f, indent=4, ensure_ascii=False)
        logger.info(
            f"Parsed and validated metadata successfully written to {output_file_path}"
        )
        return True
    except Exception as e:
        logger.error(f"An error occurred during validation or writing: {e}")
        raise
wrap_all_lines(pages)
Source code in src/tnh_scholar/journal_processing/journal_process.py
def wrap_all_lines(pages):
    return [wrap_lines(page) for page in pages]
wrap_lines(text)
Encloses each line of the input text with angle brackets.

Args:
    text (str): The input string containing lines separated by '\n'.

Returns:
    str: A string where each line is enclosed in angle brackets.

Example:
    >>> wrap_lines("This is a string with   \n   two lines.")
    '<This is a string with   >\n<   two lines.>'

Source code in src/tnh_scholar/journal_processing/journal_process.py
def wrap_lines(text: str) -> str:
    """
    Encloses each line of the input text with angle brackets.

    Args:
        text (str): The input string containing lines separated by '\n'.

    Returns:
        str: A string where each line is enclosed in angle brackets.

    Example:
        >>> wrap_lines("This is a string with   \n   two lines.")
        '<This is a string with   >\n<   two lines.>'
    """
    return "\n".join(f"<{line}>" for line in text.split("\n"))
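The transformation is a one-liner, so the docstring example can be checked directly; this sketch simply mirrors the function shown above.

```python
# Mirror of wrap_lines: split on newlines, wrap each line, rejoin.
# Whitespace inside each line is preserved verbatim.
def wrap_lines(text: str) -> str:
    return "\n".join(f"<{line}>" for line in text.split("\n"))

wrapped = wrap_lines("This is a string with   \n   two lines.")
```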

logging_config

TNH-Scholar Logging Utilities

A production-ready, environment-driven logging system for the TNH-Scholar project. It provides JSON logs in production, color/plain text in development, optional non-blocking queue logging, file rotation, noise suppression for chatty deps, and optional routing of Python warnings into the logging pipeline.

This module is designed for application layer configuration and library layer usage:

  • Applications (CLI, Streamlit, FastAPI, notebooks) call setup_logging().
  • Libraries / services (e.g., gen_ai_service, IssueHandler) only acquire a logger via get_logger() (or legacy get_child_logger()) and never configure global logging.

Quick start

Application entry point (recommended):

>>> from tnh_scholar.logging_config import setup_logging, get_logger
>>> setup_logging()  # reads env; see variables below
>>> log = get_logger(__name__)
>>> log.info("app started", extra={"service": "gen-ai"})

Jupyter / dev (force color in non-TTY):

>>> import os
>>> os.environ["APP_ENV"] = "dev"
>>> os.environ["LOG_JSON"] = "false"
>>> os.environ["LOG_COLOR"] = "true"  # Jupyter isn't a TTY; force color
>>> from tnh_scholar.logging_config import setup_logging, get_logger
>>> setup_logging()
>>> get_logger(__name__).info("hello, color")

Library / service modules (do NOT configure logging):

>>> from tnh_scholar.logging_config import get_logger
>>> log = get_logger(__name__)
>>> log.info("library message")

Behavior by environment
  • dev (default):
    • Plain or color text to stdout by default.
    • Queue logging disabled by default (synchronous).
    • Color auto-detects TTY and Jupyter/IPython (can be forced).
  • prod:
    • JSON logs to stderr by default (suitable for log shippers).
    • Queue logging enabled by default (can be disabled).

Environment variables

Most behavior is controlled by environment variables (read when setup_logging() instantiates LogSettings). Truthy values accept true/1/yes/on (case-insensitive).

  • APP_ENV: dev | prod | test (default: dev)
  • LOG_LEVEL: Logging level for the base project logger (default: INFO)
  • LOG_STDOUT: Emit logs to stdout (default: true)
  • LOG_FILE_ENABLE: Emit logs to a file (default: false)
  • LOG_FILE_PATH: File path for logs (default: ./logs/main.log)
  • LOG_ROTATE_BYTES: Rotate at N bytes (e.g., 10485760) (default: unset)
  • LOG_ROTATE_WHEN: Timed rotation (e.g., midnight) (default: unset)
  • LOG_BACKUPS: Number of rotated file backups (default: 5)
  • LOG_JSON: Use JSON formatter (recommended in prod) (default: true)
  • LOG_COLOR: true | false | auto (default: auto)
  • LOG_STREAM: stdout | stderr (default: stderr; dev defaults to stdout)
  • LOG_USE_QUEUE: Use QueueHandler/QueueListener (default: true; dev defaults to false)
  • LOG_CAPTURE_WARNINGS: Route Python warnings via logging (default: false)
  • LOG_SUPPRESS: Comma-separated list of noisy module names to set to WARNING (default includes urllib3, httpx, openai, uvicorn.*, etc.)
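The truthy parsing described above can be sketched as follows. The real helper (`_env_bool`) is internal to the module, so the name and default-handling here are assumptions for illustration.

```python
import os

# Plausible sketch of truthy env parsing: accept true/1/yes/on,
# case-insensitively, with a string default when the variable is unset.
def env_bool(name: str, default: str = "false") -> bool:
    return os.environ.get(name, default).strip().lower() in {"true", "1", "yes", "on"}

os.environ["DEMO_FLAG"] = "YES"
```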

Backward compatibility
  • get_child_logger(name, console=False, separate_file=False) remains available and can attach ad-hoc console/file handlers without reconfiguring the project base logger. When custom handlers are attached, the child’s propagation is turned off to avoid duplicate messages.
  • setup_logging_legacy(...) forwards to setup_logging() and emits a DeprecationWarning to help locate legacy call sites.
  • Custom level PRIORITY_INFO (25) and logger.priority_info() still exist but are deprecated. Prefer:

    log.info("message", extra={"priority": "high"})

This keeps level semantics standard and plays better with structured logging.


Queue logging notes
  • When LOG_USE_QUEUE=true, the base logger uses a QueueHandler. A QueueListener is started with sinks mirroring your configured stdout/file handlers. This decouples log emission from I/O to minimize latency.
  • In notebooks or during debugging, you may prefer synchronous logs:

    os.environ["LOG_USE_QUEUE"] = "false"
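The QueueHandler/QueueListener pairing can be demonstrated with the standard library alone. tnh_scholar wires this up inside setup_logging(); the names below (ListHandler, "queue_demo") are local to this sketch.

```python
import logging
import queue
from logging.handlers import QueueHandler, QueueListener

# Emission side: the logger only enqueues records (non-blocking).
# Consumption side: a listener thread forwards records to real sinks.
log_queue: queue.Queue = queue.Queue()
records = []

class ListHandler(logging.Handler):
    def emit(self, record):
        records.append(record.getMessage())

logger = logging.getLogger("queue_demo")
logger.setLevel(logging.INFO)
logger.propagate = False
logger.addHandler(QueueHandler(log_queue))  # emission never touches I/O
listener = QueueListener(log_queue, ListHandler())
listener.start()
logger.info("decoupled from I/O")
listener.stop()  # drains remaining records before returning
```

`listener.stop()` processes everything already queued, which is why orderly shutdown matters before asserting on output.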


Python warnings routing
  • When LOG_CAPTURE_WARNINGS=true, Python warnings are captured and logged through py.warnings. This module attaches the base logger’s handlers to that logger and disables propagation to avoid duplicate output.
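The py.warnings routing can be shown with only the standard library; the handler-attachment and propagate=False steps below mimic what this module does, though ListHandler is a local name for this sketch.

```python
import logging
import warnings

# Route Python warnings into logging: captureWarnings(True) redirects
# warnings.warn(...) to the 'py.warnings' logger instead of stderr.
captured = []

class ListHandler(logging.Handler):
    def emit(self, record):
        captured.append(record.getMessage())

logging.captureWarnings(True)
pyw = logging.getLogger("py.warnings")
pyw.addHandler(ListHandler())
pyw.propagate = False  # avoid duplicate output, as the module does
warnings.warn("deprecated path")
logging.captureWarnings(False)  # restore default behavior
```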

Mixing print() and logging
  • print() writes to stdout; the logger can write to stdout or stderr depending on LOG_STREAM and environment. Ordering is not guaranteed, especially with queue logging enabled. Prefer logging for consistent output.

Minimal examples

CLI / entrypoint:

>>> import os
>>> os.environ.setdefault("APP_ENV", "prod")
>>> os.environ.setdefault("LOG_JSON", "true")
>>> from tnh_scholar.logging_config import setup_logging, get_logger
>>> setup_logging()
>>> get_logger(__name__).info("ready")

File logging with rotation:

>>> import os
>>> os.environ.update({
...     "LOG_FILE_ENABLE": "true",
...     "LOG_FILE_PATH": "./logs/app.log",
...     "LOG_ROTATE_BYTES": "10485760",  # 10MB
...     "LOG_BACKUPS": "7",
... })
>>> setup_logging()
>>> get_logger("smoke").info("to file")

Jupyter with color:

>>> import os
>>> os.environ.update({"APP_ENV": "dev", "LOG_JSON": "false", "LOG_COLOR": "true"})
>>> setup_logging()
>>> get_logger(__name__).info("color in notebook")

Notes
  • JSON formatting requires python-json-logger; without it, we fall back to plain/color format automatically.
  • This module never configures the root logger; it configures the project base logger (tnh) so your app can coexist with other libraries cleanly.

BASE_LOG_DIR = Path('./logs') module-attribute

BASE_LOG_NAME = 'tnh' module-attribute

DEFAULT_CONSOLE_FORMAT_STRING = LOG_FMT_COLOR module-attribute

DEFAULT_FILE_FORMAT_STRING = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' module-attribute

DEFAULT_LOG_FILEPATH = Path('main.log') module-attribute

JsonFormatter = getattr(_pythonjsonlogger_json, 'JsonFormatter', None) module-attribute

LOG_COLORS = {'DEBUG': 'bold_green', 'INFO': 'cyan', 'PRIORITY_INFO': 'bold_cyan', 'WARNING': 'bold_yellow', 'ERROR': 'bold_red', 'CRITICAL': 'bold_red'} module-attribute

LOG_FMT_COLOR = '%(asctime)s | %(log_color)s%(levelname)-8s%(reset)s | %(name)s | %(message)s' module-attribute

LOG_FMT_JSON = '%(asctime)s %(levelname)s %(name)s %(message)s %(process)d %(thread)d %(module)s %(filename)s %(lineno)d' module-attribute

LOG_FMT_PLAIN = '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s' module-attribute

MAX_FILE_SIZE = 10 * 1024 * 1024 module-attribute

PRIORITY_INFO_LEVEL = 25 module-attribute

__all__ = ['BASE_LOG_NAME', 'BASE_LOG_DIR', 'DEFAULT_LOG_FILEPATH', 'MAX_FILE_SIZE', 'OMPFilter', 'setup_logging', 'setup_logging_legacy', 'get_logger', 'get_child_logger'] module-attribute

LogSettings dataclass

Environment-driven logging settings with sensible defaults.

Source code in src/tnh_scholar/logging_config.py
@dataclass
class LogSettings:
    """Environment-driven logging settings with sensible defaults."""
    # Mode
    environment: str = field(default_factory=lambda: _env_str("APP_ENV", "dev"))  # dev|prod|test
    base_name: str = field(default_factory=lambda: _env_str("LOG_BASE", BASE_LOG_NAME))

    # Level
    level: str = field(default_factory=lambda: _env_str("LOG_LEVEL", "INFO"))

    # Outputs
    to_stdout: bool = field(default_factory=lambda: _env_bool("LOG_STDOUT", "true"))
    to_file: bool = field(default_factory=lambda: _env_bool("LOG_FILE_ENABLE", "false"))
    file_path: Path = field(
        default_factory=lambda: Path(
            _env_str("LOG_FILE_PATH", str(BASE_LOG_DIR / DEFAULT_LOG_FILEPATH))
            )
        )

    # File rotation
    rotate_when: Optional[str] = field(default_factory=lambda: _env_str("LOG_ROTATE_WHEN", "") or None)  
        # e.g. 'midnight'
    rotate_bytes: Optional[int] = field(default_factory=lambda: (_env_int("LOG_ROTATE_BYTES", 0) or None))  
        # e.g. 10485760
    backups: int = field(default_factory=lambda: _env_int("LOG_BACKUPS", 5))

    # Format
    json_format: bool = field(default_factory=lambda: _env_bool("LOG_JSON", "true"))  # prod default
    colorize: str = field(default_factory=lambda: _env_str("LOG_COLOR", "auto"))  # true|false|auto

    # Python warnings routing
    capture_warnings: bool = field(default_factory=lambda: _env_bool("LOG_CAPTURE_WARNINGS", "false"))

    # Stream selection (stdout|stderr)
    log_stream: str = field(default_factory=lambda: _env_str("LOG_STREAM", "stderr"))

    # Performance
    use_queue: bool = field(default_factory=lambda: _env_bool("LOG_USE_QUEUE", "true"))

    # Noise suppression (comma-separated)
    suppress_modules: str = field(default_factory=lambda: _env_str(
        "LOG_SUPPRESS",
        "urllib3,httpx,openai,botocore,boto3,asyncio,uvicorn,uvicorn.error,uvicorn.access",
    ))

    def is_dev(self) -> bool:
        return self.environment.lower() == "dev"

    def should_color(self) -> bool:
        if self.colorize == "true":
            return True
        if self.colorize == "false":
            return False
        # auto: TTY or Jupyter/IPython
        if _is_tty(self.selected_stream()):
            return True
        try:
            from IPython.core.getipython import get_ipython
            return get_ipython() is not None  # in a notebook/console
        except Exception:
            return False

    def selected_stream(self):
        """Return the Python stream object to emit logs to (stdout or stderr)."""
        return sys.stdout if self.log_stream.lower() == "stdout" else sys.stderr

    def __post_init__(self):
        # Default to stdout and no-queue in dev, unless explicitly overridden by env
        if self.is_dev():
            if "LOG_STREAM" not in os.environ:
                self.log_stream = "stdout"
            if "LOG_USE_QUEUE" not in os.environ:
                self.use_queue = False
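The `default_factory` pattern LogSettings relies on can be reduced to one field: each field reads its environment variable at instantiation time, not at import time, so later env changes take effect on the next construction. MiniSettings is a hypothetical name for this sketch.

```python
import os
from dataclasses import dataclass, field

# Each field's default_factory runs when the dataclass is instantiated,
# so the environment is consulted per-instance, not once at import.
@dataclass
class MiniSettings:
    level: str = field(default_factory=lambda: os.environ.get("LOG_LEVEL", "INFO"))

os.environ.pop("LOG_LEVEL", None)
default = MiniSettings()          # falls back to "INFO"
os.environ["LOG_LEVEL"] = "DEBUG"
overridden = MiniSettings()       # picks up the new value
```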
backups = field(default_factory=(lambda: _env_int('LOG_BACKUPS', 5))) class-attribute instance-attribute
base_name = field(default_factory=(lambda: _env_str('LOG_BASE', BASE_LOG_NAME))) class-attribute instance-attribute
capture_warnings = field(default_factory=(lambda: _env_bool('LOG_CAPTURE_WARNINGS', 'false'))) class-attribute instance-attribute
colorize = field(default_factory=(lambda: _env_str('LOG_COLOR', 'auto'))) class-attribute instance-attribute
environment = field(default_factory=(lambda: _env_str('APP_ENV', 'dev'))) class-attribute instance-attribute
file_path = field(default_factory=(lambda: Path(_env_str('LOG_FILE_PATH', str(BASE_LOG_DIR / DEFAULT_LOG_FILEPATH))))) class-attribute instance-attribute
json_format = field(default_factory=(lambda: _env_bool('LOG_JSON', 'true'))) class-attribute instance-attribute
level = field(default_factory=(lambda: _env_str('LOG_LEVEL', 'INFO'))) class-attribute instance-attribute
log_stream = field(default_factory=(lambda: _env_str('LOG_STREAM', 'stderr'))) class-attribute instance-attribute
rotate_bytes = field(default_factory=(lambda: _env_int('LOG_ROTATE_BYTES', 0) or None)) class-attribute instance-attribute
rotate_when = field(default_factory=(lambda: _env_str('LOG_ROTATE_WHEN', '') or None)) class-attribute instance-attribute
suppress_modules = field(default_factory=(lambda: _env_str('LOG_SUPPRESS', 'urllib3,httpx,openai,botocore,boto3,asyncio,uvicorn,uvicorn.error,uvicorn.access'))) class-attribute instance-attribute
to_file = field(default_factory=(lambda: _env_bool('LOG_FILE_ENABLE', 'false'))) class-attribute instance-attribute
to_stdout = field(default_factory=(lambda: _env_bool('LOG_STDOUT', 'true'))) class-attribute instance-attribute
use_queue = field(default_factory=(lambda: _env_bool('LOG_USE_QUEUE', 'true'))) class-attribute instance-attribute
__init__(environment=(lambda: _env_str('APP_ENV', 'dev'))(), base_name=(lambda: _env_str('LOG_BASE', BASE_LOG_NAME))(), level=(lambda: _env_str('LOG_LEVEL', 'INFO'))(), to_stdout=(lambda: _env_bool('LOG_STDOUT', 'true'))(), to_file=(lambda: _env_bool('LOG_FILE_ENABLE', 'false'))(), file_path=(lambda: Path(_env_str('LOG_FILE_PATH', str(BASE_LOG_DIR / DEFAULT_LOG_FILEPATH))))(), rotate_when=(lambda: _env_str('LOG_ROTATE_WHEN', '') or None)(), rotate_bytes=(lambda: _env_int('LOG_ROTATE_BYTES', 0) or None)(), backups=(lambda: _env_int('LOG_BACKUPS', 5))(), json_format=(lambda: _env_bool('LOG_JSON', 'true'))(), colorize=(lambda: _env_str('LOG_COLOR', 'auto'))(), capture_warnings=(lambda: _env_bool('LOG_CAPTURE_WARNINGS', 'false'))(), log_stream=(lambda: _env_str('LOG_STREAM', 'stderr'))(), use_queue=(lambda: _env_bool('LOG_USE_QUEUE', 'true'))(), suppress_modules=(lambda: _env_str('LOG_SUPPRESS', 'urllib3,httpx,openai,botocore,boto3,asyncio,uvicorn,uvicorn.error,uvicorn.access'))())
__post_init__()
Source code in src/tnh_scholar/logging_config.py
def __post_init__(self):
    # Default to stdout and no-queue in dev, unless explicitly overridden by env
    if self.is_dev():
        if "LOG_STREAM" not in os.environ:
            self.log_stream = "stdout"
        if "LOG_USE_QUEUE" not in os.environ:
            self.use_queue = False
is_dev()
Source code in src/tnh_scholar/logging_config.py
def is_dev(self) -> bool:
    return self.environment.lower() == "dev"
selected_stream()

Return the Python stream object to emit logs to (stdout or stderr).

Source code in src/tnh_scholar/logging_config.py
def selected_stream(self):
    """Return the Python stream object to emit logs to (stdout or stderr)."""
    return sys.stdout if self.log_stream.lower() == "stdout" else sys.stderr
should_color()
Source code in src/tnh_scholar/logging_config.py
def should_color(self) -> bool:
    if self.colorize == "true":
        return True
    if self.colorize == "false":
        return False
    # auto: TTY or Jupyter/IPython
    if _is_tty(self.selected_stream()):
        return True
    try:
        from IPython.core.getipython import get_ipython
        return get_ipython() is not None  # in a notebook/console
    except Exception:
        return False

LoggingConfigurator

Source code in src/tnh_scholar/logging_config.py
class LoggingConfigurator:
    """Modular builder for project-wide logging configuration."""

    _queue: Optional[queue.Queue] = None

    # ----- Private helpers (handlers) -----
    def _stdout_handler_config(self, fmt_key: str) -> dict:
        stream_path = "ext://sys.stdout" if self.settings.log_stream.lower() == "stdout" else "ext://sys.stderr"
        return {
            "class": "logging.StreamHandler",
            "stream": stream_path,
            "formatter": fmt_key,
            "filters": ["omp_filter"],
        }

    def _file_handler_config(self, *, formatter_key: str) -> dict:
        s = self.settings
        s.file_path.parent.mkdir(parents=True, exist_ok=True)
        if s.rotate_bytes:
            return {
                "class": "logging.handlers.RotatingFileHandler",
                "maxBytes": s.rotate_bytes,
                "backupCount": s.backups,
                "filename": str(s.file_path),
                "formatter": formatter_key,
                "encoding": "utf-8",
                "filters": ["omp_filter"],
            }
        if s.rotate_when:
            return {
                "class": "logging.handlers.TimedRotatingFileHandler",
                "when": s.rotate_when,
                "backupCount": s.backups,
                "filename": str(s.file_path),
                "formatter": formatter_key,
                "encoding": "utf-8",
                "filters": ["omp_filter"],
            }
        return {
            "class": "logging.FileHandler",
            "filename": str(s.file_path),
            "formatter": formatter_key,
            "encoding": "utf-8",
            "filters": ["omp_filter"],
        }

    def __init__(self, settings: Optional[LogSettings] = None):
        self.settings = settings or LogSettings()
        # persistent queue instance for QueueHandler/Listener pairing
        self._queue = queue.Queue() if self.settings.use_queue else None

    # ----- Legacy-args bridge -----
    def apply_legacy_args(
        self,
        *,
        log_level,
        log_filepath,
        max_log_file_size,
        backup_count,
        console,
    ) -> None:
        s = self.settings
        s.level = (logging.getLevelName(log_level) if isinstance(log_level, int) else str(log_level)).upper()
        if console is False:
            s.to_stdout = False
            s.to_file = True
        if log_filepath != DEFAULT_LOG_FILEPATH:
            s.to_file = True
            s.file_path = BASE_LOG_DIR / Path(log_filepath)
        if max_log_file_size and max_log_file_size != MAX_FILE_SIZE:
            s.rotate_bytes = int(max_log_file_size)
        s.backups = backup_count or s.backups

    # ----- Builders -----
    def build_formatters(self) -> dict:
        s = self.settings
        fmts: dict[str, dict] = {}
        if s.json_format and JsonFormatter is not None:
            fmts["json"] = {
                "()": "pythonjsonlogger.json.JsonFormatter",
                "fmt": LOG_FMT_JSON,
                "json_ensure_ascii": False,
            }
        else:
            fmts["plain"] = {
                "()": f"{__name__}.UtcFormatter",
                "fmt": LOG_FMT_PLAIN,
            }
            if s.is_dev() and colorlog and s.should_color():
                fmts["color"] = {
                    "()": "colorlog.ColoredFormatter",
                    "format": LOG_FMT_COLOR,
                    "log_colors": LOG_COLORS,
                }
        return fmts

    def build_filters(self) -> dict:
        return {"omp_filter": {"()": f"{__name__}.OMPFilter"}}

    def build_handlers(self, formatters: dict) -> dict:
        s = self.settings
        handlers: dict[str, dict] = {}

        # stdout handler
        if s.to_stdout:
            if s.json_format and JsonFormatter is not None:
                fmt = "json"
            elif s.is_dev() and colorlog and s.should_color():
                fmt = "color"
            else:
                fmt = "plain"
            handlers["stdout"] = self._stdout_handler_config(fmt)

        # file handler
        formatter_key = "json" if (s.json_format and JsonFormatter is not None) else "plain"
        if s.to_file:
            handlers["file"] = self._file_handler_config(formatter_key=formatter_key)

        # queue wrapper
        if s.use_queue and handlers:
            if self._queue is None:
                self._queue = queue.Queue()
            handlers["queue"] = {
                "class": "logging.handlers.QueueHandler",
                "queue": self._queue,
            }
        return handlers

    # ----- Private helpers (queue sinks) -----
    def _make_stream_sink(self) -> logging.Handler:
        s = self.settings
        sh = logging.StreamHandler(self.settings.selected_stream())
        if s.json_format and JsonFormatter is not None:
            sh.setFormatter(JsonFormatter(LOG_FMT_JSON))
        elif s.is_dev() and colorlog and s.should_color():
            sh.setFormatter(colorlog.ColoredFormatter(LOG_FMT_COLOR, log_colors=LOG_COLORS))
        else:
            sh.setFormatter(UtcFormatter(LOG_FMT_PLAIN))
        sh.addFilter(OMPFilter())
        return sh

    def _make_file_sink(self) -> logging.Handler:
        s = self.settings
        if s.rotate_bytes:
            fh: logging.Handler = RotatingFileHandler(
                str(s.file_path),
                maxBytes=s.rotate_bytes,
                backupCount=s.backups,
                encoding="utf-8",
            )
        elif s.rotate_when:
            fh = TimedRotatingFileHandler(
                str(s.file_path),
                when=s.rotate_when,
                backupCount=s.backups,
                encoding="utf-8",
            )
        else:
            fh = logging.FileHandler(str(s.file_path), encoding="utf-8")

        if s.json_format and JsonFormatter is not None:
            fh.setFormatter(JsonFormatter(LOG_FMT_JSON))
        else:
            fh.setFormatter(UtcFormatter(LOG_FMT_PLAIN))
        fh.addFilter(OMPFilter())
        return fh

    def select_base_handlers(self, handlers: dict) -> list[str]:
        s = self.settings
        base_handlers: list[str] = []
        if s.use_queue and ("queue" in handlers):
            base_handlers.append("queue")
        else:
            if "stdout" in handlers:
                base_handlers.append("stdout")
            if "file" in handlers:
                base_handlers.append("file")
        return base_handlers

    def build_config(self, *, filters: dict, formatters: dict, handlers: dict) -> dict:
        s = self.settings
        return {
            "version": 1,
            "disable_existing_loggers": False,
            "filters": filters,
            "formatters": formatters,
            "handlers": handlers,
            "loggers": {
                s.base_name: {
                    "level": s.level,
                    "handlers": self.select_base_handlers(handlers),
                    "propagate": False,
                }
            },
        }

    def apply_config(self, config: dict) -> None:
        logging.config.dictConfig(config)
        logging.captureWarnings(self.settings.capture_warnings)
        # When routing Python warnings into logging, the records go to 'py.warnings'.
        # Attach our base handlers so warnings are visible.
        if self.settings.capture_warnings:
            base = logging.getLogger(self.settings.base_name)
            pyw = logging.getLogger("py.warnings")
            # Avoid duplicate handlers on re-configure
            existing = {id(h) for h in pyw.handlers}
            for h in base.handlers:
                if id(h) not in existing:
                    pyw.addHandler(h)
            # Ensure records are emitted even if root has no handlers
            pyw.setLevel(logging.WARNING)
            pyw.propagate = False

    def start_queue_listener(self, handlers: dict) -> None:
        global _queue_listener
        s = self.settings
        if not (s.use_queue and ("queue" in handlers)):
            return
        q_logger = logging.getLogger(s.base_name)
        qh = next((h for h in q_logger.handlers if isinstance(h, QueueHandler)), None)
        if qh is None:
            return
        q = qh.queue  # type: ignore[attr-defined]

        sink_handlers: list[logging.Handler] = []
        if "stdout" in handlers:
            sink_handlers.append(self._make_stream_sink())
        if "file" in handlers and s.to_file:
            sink_handlers.append(self._make_file_sink())

        if _queue_listener:
            with contextlib.suppress(Exception):
                _queue_listener.stop()
        _queue_listener = QueueListener(q, *sink_handlers, respect_handler_level=True)
        _queue_listener.start()

    def suppress_noise(self, modules_override, force: bool = False) -> None:
        s = self.settings
        modules = modules_override
        # Normalize to a list of module names (strings)
        if modules is None:
            modules = s.suppress_modules  # env string by default
        if isinstance(modules, str):
            modules_list = [m.strip() for m in modules.split(",") if m.strip()]
        else:
            # Attempt to iterate; if not iterable, coerce to single-item list
            try:
                modules_list = [str(m).strip() for m in modules if str(m).strip()]
            except TypeError:
                modules_list = [str(modules).strip()] if str(modules).strip() else []
        for module in modules_list:
            logger = logging.getLogger(module)
            if force or logger.level == logging.NOTSET:
                logger.setLevel(logging.WARNING)

    # ----- Facade -----
    def configure(
        self,
        *,
        legacy_args: dict,
        suppressed_modules,
    ) -> logging.Logger:
        self.apply_legacy_args(**legacy_args)
        formatters = self.build_formatters()
        filters = self.build_filters()
        handlers = self.build_handlers(formatters)
        config = self.build_config(filters=filters, formatters=formatters, handlers=handlers)
        self.apply_config(config)
        self.start_queue_listener(handlers)
        self.suppress_noise(suppressed_modules, force=False)
        return logging.getLogger(self.settings.base_name)
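A compact dictConfig example mirroring the shape build_config() produces: one named logger with its own handlers and propagate=False, leaving the root logger untouched. The "plain" format string here is a simplification of the module's formats.

```python
import logging
import logging.config

# Configure only the project base logger ("tnh"); other libraries and
# the root logger are unaffected (disable_existing_loggers=False).
logging.config.dictConfig({
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {"plain": {"format": "%(levelname)s | %(name)s | %(message)s"}},
    "handlers": {"stdout": {"class": "logging.StreamHandler",
                            "stream": "ext://sys.stdout",
                            "formatter": "plain"}},
    "loggers": {"tnh": {"level": "DEBUG",
                        "handlers": ["stdout"],
                        "propagate": False}},
})
log = logging.getLogger("tnh")
```

Because the config never touches the root logger, an application embedding tnh_scholar keeps its own logging intact.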
settings = settings or LogSettings() instance-attribute
__init__(settings=None)
Source code in src/tnh_scholar/logging_config.py
def __init__(self, settings: Optional[LogSettings] = None):
    self.settings = settings or LogSettings()
    # persistent queue instance for QueueHandler/Listener pairing
    self._queue = queue.Queue() if self.settings.use_queue else None
apply_config(config)
Source code in src/tnh_scholar/logging_config.py
def apply_config(self, config: dict) -> None:
    logging.config.dictConfig(config)
    logging.captureWarnings(self.settings.capture_warnings)
    # When routing Python warnings into logging, the records go to 'py.warnings'.
    # Attach our base handlers so warnings are visible.
    if self.settings.capture_warnings:
        base = logging.getLogger(self.settings.base_name)
        pyw = logging.getLogger("py.warnings")
        # Avoid duplicate handlers on re-configure
        existing = {id(h) for h in pyw.handlers}
        for h in base.handlers:
            if id(h) not in existing:
                pyw.addHandler(h)
        # Ensure records are emitted even if root has no handlers
        pyw.setLevel(logging.WARNING)
        pyw.propagate = False
apply_legacy_args(*, log_level, log_filepath, max_log_file_size, backup_count, console)
Source code in src/tnh_scholar/logging_config.py
def apply_legacy_args(
    self,
    *,
    log_level,
    log_filepath,
    max_log_file_size,
    backup_count,
    console,
) -> None:
    s = self.settings
    s.level = (logging.getLevelName(log_level) if isinstance(log_level, int) else str(log_level)).upper()
    if console is False:
        s.to_stdout = False
        s.to_file = True
    if log_filepath != DEFAULT_LOG_FILEPATH:
        s.to_file = True
        s.file_path = BASE_LOG_DIR / Path(log_filepath)
    if max_log_file_size and max_log_file_size != MAX_FILE_SIZE:
        s.rotate_bytes = int(max_log_file_size)
    s.backups = backup_count or s.backups
build_config(*, filters, formatters, handlers)
Source code in src/tnh_scholar/logging_config.py, lines 532-547
def build_config(self, *, filters: dict, formatters: dict, handlers: dict) -> dict:
    s = self.settings
    return {
        "version": 1,
        "disable_existing_loggers": False,
        "filters": filters,
        "formatters": formatters,
        "handlers": handlers,
        "loggers": {
            s.base_name: {
                "level": s.level,
                "handlers": self.select_base_handlers(handlers),
                "propagate": False,
            }
        },
    }
build_filters()
Source code in src/tnh_scholar/logging_config.py, lines 449-450
def build_filters(self) -> dict:
    return {"omp_filter": {"()": f"{__name__}.OMPFilter"}}
build_formatters()
Source code in src/tnh_scholar/logging_config.py, lines 427-447
def build_formatters(self) -> dict:
    s = self.settings
    fmts: dict[str, dict] = {}
    if s.json_format and JsonFormatter is not None:
        fmts["json"] = {
            "()": "pythonjsonlogger.json.JsonFormatter",
            "fmt": LOG_FMT_JSON,
            "json_ensure_ascii": False,
        }
    else:
        fmts["plain"] = {
            "()": f"{__name__}.UtcFormatter",
            "fmt": LOG_FMT_PLAIN,
        }
        if s.is_dev() and colorlog and s.should_color():
            fmts["color"] = {
                "()": "colorlog.ColoredFormatter",
                "format": LOG_FMT_COLOR,
                "log_colors": LOG_COLORS,
            }
    return fmts
build_handlers(formatters)
Source code in src/tnh_scholar/logging_config.py, lines 452-479
def build_handlers(self, formatters: dict) -> dict:
    s = self.settings
    handlers: dict[str, dict] = {}

    # stdout handler
    if s.to_stdout:
        if s.json_format and JsonFormatter is not None:
            fmt = "json"
        elif s.is_dev() and colorlog and s.should_color():
            fmt = "color"
        else:
            fmt = "plain"
        handlers["stdout"] = self._stdout_handler_config(fmt)

    # file handler
    formatter_key = "json" if (s.json_format and JsonFormatter is not None) else "plain"
    if s.to_file:
        handlers["file"] = self._file_handler_config(formatter_key=formatter_key)

    # queue wrapper
    if s.use_queue and handlers:
        if self._queue is None:
            self._queue = queue.Queue()
        handlers["queue"] = {
            "class": "logging.handlers.QueueHandler",
            "queue": self._queue,
        }
    return handlers
configure(*, legacy_args, suppressed_modules)
Source code in src/tnh_scholar/logging_config.py, lines 609-623
def configure(
    self,
    *,
    legacy_args: dict,
    suppressed_modules,
) -> logging.Logger:
    self.apply_legacy_args(**legacy_args)
    formatters = self.build_formatters()
    filters = self.build_filters()
    handlers = self.build_handlers(formatters)
    config = self.build_config(filters=filters, formatters=formatters, handlers=handlers)
    self.apply_config(config)
    self.start_queue_listener(handlers)
    self.suppress_noise(suppressed_modules, force=False)
    return logging.getLogger(self.settings.base_name)
select_base_handlers(handlers)
Source code in src/tnh_scholar/logging_config.py, lines 520-530
def select_base_handlers(self, handlers: dict) -> list[str]:
    s = self.settings
    base_handlers: list[str] = []
    if s.use_queue and ("queue" in handlers):
        base_handlers.append("queue")
    else:
        if "stdout" in handlers:
            base_handlers.append("stdout")
        if "file" in handlers:
            base_handlers.append("file")
    return base_handlers
start_queue_listener(handlers)
Source code in src/tnh_scholar/logging_config.py, lines 566-587
def start_queue_listener(self, handlers: dict) -> None:
    global _queue_listener
    s = self.settings
    if not (s.use_queue and ("queue" in handlers)):
        return
    q_logger = logging.getLogger(s.base_name)
    qh = next((h for h in q_logger.handlers if isinstance(h, QueueHandler)), None)
    if qh is None:
        return
    q = qh.queue  # type: ignore[attr-defined]

    sink_handlers: list[logging.Handler] = []
    if "stdout" in handlers:
        sink_handlers.append(self._make_stream_sink())
    if "file" in handlers and s.to_file:
        sink_handlers.append(self._make_file_sink())

    if _queue_listener:
        with contextlib.suppress(Exception):
            _queue_listener.stop()
    _queue_listener = QueueListener(q, *sink_handlers, respect_handler_level=True)
    _queue_listener.start()
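The QueueHandler/QueueListener pairing above follows the standard-library pattern: the logger enqueues records via a QueueHandler, and a background QueueListener drains the queue into the real sink handlers. A self-contained sketch (logger name and sink handler are illustrative):

```python
import logging
import queue
from logging.handlers import QueueHandler, QueueListener

log_queue: "queue.Queue[logging.LogRecord]" = queue.Queue()
captured = []

class ListSink(logging.Handler):
    def emit(self, record):
        captured.append(record.getMessage())

logger = logging.getLogger("queue_demo")
logger.setLevel(logging.INFO)
logger.addHandler(QueueHandler(log_queue))  # producer side: enqueue only
logger.propagate = False

# Consumer side: listener thread forwards records to the sinks.
listener = QueueListener(log_queue, ListSink(), respect_handler_level=True)
listener.start()
logger.info("hello via queue")
listener.stop()  # drains remaining records before returning
```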
suppress_noise(modules_override, force=False)
Source code in src/tnh_scholar/logging_config.py, lines 589-606
def suppress_noise(self, modules_override, force: bool = False) -> None:
    s = self.settings
    modules = modules_override
    # Normalize to a list of module names (strings)
    if modules is None:
        modules = s.suppress_modules  # env string by default
    if isinstance(modules, str):
        modules_list = [m.strip() for m in modules.split(",") if m.strip()]
    else:
        # Attempt to iterate; if not iterable, coerce to single-item list
        try:
            modules_list = [str(m).strip() for m in modules if str(m).strip()]
        except TypeError:
            modules_list = [str(modules).strip()] if str(modules).strip() else []
    for module in modules_list:
        logger = logging.getLogger(module)
        if force or logger.level == logging.NOTSET:
            logger.setLevel(logging.WARNING)
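The normalization above accepts either a comma-separated string or an iterable of module names, then raises each unconfigured logger to WARNING. The same logic in isolation (the `suppress` helper name is illustrative):

```python
import logging

def suppress(modules, force=False):
    # Accept "a, b" or ["a", "b"]; skip blanks either way.
    if isinstance(modules, str):
        names = [m.strip() for m in modules.split(",") if m.strip()]
    else:
        names = [str(m).strip() for m in modules if str(m).strip()]
    for name in names:
        lg = logging.getLogger(name)
        # Only quiet loggers that were never explicitly configured,
        # unless force=True overrides existing levels.
        if force or lg.level == logging.NOTSET:
            lg.setLevel(logging.WARNING)

suppress("noisy_a, noisy_b")
```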

OMPFilter

Bases: Filter

Source code in src/tnh_scholar/logging_config.py, lines 654-657
class OMPFilter(logging.Filter):
    def filter(self, record):
        # Suppress messages containing "OMP:"
        return "OMP:" not in record.getMessage()
filter(record)
Source code in src/tnh_scholar/logging_config.py, lines 655-657
def filter(self, record):
    # Suppress messages containing "OMP:"
    return "OMP:" not in record.getMessage()
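Attached to a handler, the filter drops any record whose message contains "OMP:" before it is emitted. A standalone check (logger name and list handler are illustrative):

```python
import logging

class OMPFilter(logging.Filter):
    def filter(self, record):
        # Suppress messages containing "OMP:"
        return "OMP:" not in record.getMessage()

captured = []

class ListHandler(logging.Handler):
    def emit(self, record):
        captured.append(record.getMessage())

lg = logging.getLogger("omp_demo")
lg.setLevel(logging.INFO)
handler = ListHandler()
handler.addFilter(OMPFilter())  # filter runs before emit()
lg.addHandler(handler)
lg.propagate = False

lg.info("OMP: Info #276 thread noise")  # dropped by the filter
lg.info("real message")                 # passes through
```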

UtcFormatter

Bases: Formatter

UTC ISO-8601 timestamps for plain text logging.

Source code in src/tnh_scholar/logging_config.py, lines 265-274
class UtcFormatter(logging.Formatter):
    """UTC ISO-8601 timestamps for plain text logging."""
    # logging.Formatter.converter must accept (float | None) and return struct_time;
    # time.gmtime satisfies that contract and returns a UTC struct_time.
    converter = time.gmtime

    def formatTime(self, record, datefmt=None):
        if datefmt:
            return super().formatTime(record, datefmt)
        return datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat()
converter = time.gmtime class-attribute instance-attribute
formatTime(record, datefmt=None)
Source code in src/tnh_scholar/logging_config.py, lines 271-274
def formatTime(self, record, datefmt=None):
    if datefmt:
        return super().formatTime(record, datefmt)
    return datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat()
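With no `datefmt`, the formatter emits a full ISO-8601 UTC timestamp (ending in "+00:00") rather than the default local-time format. A standalone check using a hand-built LogRecord:

```python
import logging
import time
from datetime import datetime, timezone

class UtcFormatter(logging.Formatter):
    """UTC ISO-8601 timestamps for plain text logging."""
    converter = time.gmtime  # used by the datefmt path in formatTime

    def formatTime(self, record, datefmt=None):
        if datefmt:
            return super().formatTime(record, datefmt)
        return datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat()

record = logging.LogRecord("demo", logging.INFO, __file__, 1, "msg", None, None)
stamp = UtcFormatter().formatTime(record)
```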

get_child_logger(name, console=False, separate_file=False)

Get a child logger that writes logs to a console or a specified file.

Parameters:

Name Type Description Default
name str

The name of the child logger (e.g., module name).

required
console bool

If True, also log to the console. If False (default), do not log to the console.

False
separate_file bool

If True, add a rotating file handler that writes to a logfile named after the logger under the root logs directory.

False

Returns:

Type Description

logging.Logger: Configured child logger.

Source code in src/tnh_scholar/logging_config.py, lines 660-710
def get_child_logger(name: str, console: bool = False, separate_file: bool = False):
    """
    Get a child logger that writes logs to a console or a specified file.

    Args:
        name (str): The name of the child logger (e.g., module name).
        console (bool, optional): If True, also log to the console. Defaults to False.
        separate_file (bool, optional): If True, add a rotating file handler that writes to a
            logfile named after the logger under the root logs directory. Defaults to False.

    Returns:
        logging.Logger: Configured child logger.
    """

    def _setup_logfile(name, logger):
        logfile = BASE_LOG_DIR / f"{name}.log"
        logfile.parent.mkdir(parents=True, exist_ok=True)  # Ensure directory exists
        file_handler = RotatingFileHandler(
            filename=str(logfile),
            maxBytes=MAX_FILE_SIZE,  # Use the global MAX_FILE_SIZE
            backupCount=5,
            encoding="utf-8",
        )
        file_formatter = logging.Formatter(DEFAULT_FILE_FORMAT_STRING)
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)

    # Create the fully qualified child logger name
    full_name = f"{BASE_LOG_NAME}.{name}"
    logger = logging.getLogger(full_name)

    # Check if the logger already has handlers to avoid duplication
    if not logger.handlers:
        # Add console handler if specified
        if console:
            console_handler = colorlog.StreamHandler()
            console_formatter = colorlog.ColoredFormatter(
                DEFAULT_CONSOLE_FORMAT_STRING,
                log_colors=LOG_COLORS,
            )
            console_handler.setFormatter(console_formatter)
            logger.addHandler(console_handler)

        # Add file handler if a file path is provided
        if separate_file:
            _setup_logfile(name, logger)
        # Prevent duplication if we've attached custom handlers
        logger.propagate = not console and not separate_file

    return logger

get_logger(name)

Preferred helper: returns a namespaced logger under the base project name.

Backwards-compatible with existing call sites that used get_child_logger(name).

Source code in src/tnh_scholar/logging_config.py, lines 714-719
def get_logger(name: str) -> logging.Logger:
    """Preferred helper: returns a namespaced logger under the base project name.

    Backwards-compatible with existing call sites that used get_child_logger(__name__).
    """
    return logging.getLogger(f"{BASE_LOG_NAME}.{name}")
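The helper simply namespaces loggers under the project base name. Equivalent behavior, assuming a hypothetical base name of "tnh" (the package's actual `BASE_LOG_NAME` value may differ):

```python
import logging

BASE_LOG_NAME = "tnh"  # placeholder; the real constant lives in logging_config

def get_logger(name: str) -> logging.Logger:
    # Child loggers inherit handlers/levels from the base project logger.
    return logging.getLogger(f"{BASE_LOG_NAME}.{name}")

child = get_logger("audio_processing")
```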

priority_info(self, message, *args, **kwargs)

Deprecated: use logger.info(msg, extra={"priority": "high"}) instead.

This custom level (25) was introduced for highlighting important informational events, but it complicates interoperability with external log shippers and structured log processing. The recommended migration path is to log at the standard INFO level with an added extra field indicating priority.

Example

logger.info("Important event", extra={"priority": "high"})

Source code in src/tnh_scholar/logging_config.py, lines 195-217
def priority_info(self, message, *args, **kwargs):
    """
    Deprecated: use `logger.info(msg, extra={"priority": "high"})` instead.

    This custom level (25) was introduced for highlighting important informational
    events, but it complicates interoperability with external log shippers and
    structured log processing. The recommended migration path is to log at the
    standard INFO level with an added `extra` field indicating priority.

    Example:
        >>> logger.info("Important event", extra={"priority": "high"})
    """
    warnings.warn(
        "logger.priority_info() is deprecated; use logger.info(..., extra={'priority': 'high'}) instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    if self.isEnabledFor(PRIORITY_INFO_LEVEL):
        # Log normally at PRIORITY_INFO_LEVEL for backward compatibility
        self._log(PRIORITY_INFO_LEVEL, message, args, **kwargs)
    else:
        # Fallback to standard INFO level if not explicitly handled
        self.info(message, *args, **kwargs)
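The recommended migration path can be verified standalone: a plain INFO call with `extra={"priority": "high"}` attaches a `priority` attribute to the record, which downstream handlers and formatters can read (logger name and list handler are illustrative):

```python
import logging

captured = []

class ListHandler(logging.Handler):
    def emit(self, record):
        # `extra` keys become attributes on the LogRecord.
        captured.append((record.getMessage(), getattr(record, "priority", None)))

lg = logging.getLogger("priority_demo")
lg.setLevel(logging.INFO)
lg.addHandler(ListHandler())
lg.propagate = False

lg.info("Important event", extra={"priority": "high"})
```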

setup_logging(log_level=logging.INFO, log_filepath=DEFAULT_LOG_FILEPATH, max_log_file_size=MAX_FILE_SIZE, backup_count=5, console=True, suppressed_modules=None, *, settings=None)

Initialize project-wide logging using dictConfig, with JSON in prod and colorized/plain text in dev.

Backward compatible with previous signature. Prefer using env vars or pass a LogSettings via the keyword-only settings parameter.

Source code in src/tnh_scholar/logging_config.py, lines 626-651
def setup_logging(
    log_level=logging.INFO,
    log_filepath=DEFAULT_LOG_FILEPATH,
    max_log_file_size=MAX_FILE_SIZE,  # 10MB
    backup_count=5,
    console=True,
    suppressed_modules=None,
    *,
    settings: "LogSettings|None" = None,
) -> logging.Logger:
    """
    Initialize project-wide logging using dictConfig, with JSON in prod and colorized/plain text in dev.

    Backward compatible with previous signature. Prefer using env vars or pass a LogSettings via the
    keyword-only `settings` parameter.
    """
    global _queue_listener
    configurator = LoggingConfigurator(settings=settings)
    legacy_args = {
        "log_level": log_level,
        "log_filepath": log_filepath,
        "max_log_file_size": max_log_file_size,
        "backup_count": backup_count,
        "console": console,
    }
    return configurator.configure(legacy_args=legacy_args, suppressed_modules=suppressed_modules)
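The configuration ultimately flows through `logging.config.dictConfig` with a dedicated base logger that does not propagate to root. A minimal dictConfig of the same shape (logger and handler names here are illustrative, not the package's actual config):

```python
import logging
import logging.config

config = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {"plain": {"format": "%(levelname)s %(name)s %(message)s"}},
    "handlers": {
        "stdout": {
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stdout",
            "formatter": "plain",
        }
    },
    "loggers": {
        # Base project logger owns the handlers; propagate=False keeps
        # records from duplicating through the root logger.
        "demo_base": {"level": "INFO", "handlers": ["stdout"], "propagate": False}
    },
}
logging.config.dictConfig(config)
base = logging.getLogger("demo_base")
```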

setup_logging_legacy(*args, **kwargs)

Deprecated: use setup_logging().

This wrapper preserves old call sites during migration. It emits a DeprecationWarning (once per process) and forwards all arguments to the current setup_logging().

Source code in src/tnh_scholar/logging_config.py, lines 722-733
def setup_logging_legacy(*args, **kwargs) -> logging.Logger:
    """Deprecated: use setup_logging().

    This wrapper preserves old call sites during migration. It emits a DeprecationWarning
    (once per process) and forwards all arguments to the current setup_logging().
    """
    warnings.warn(
        "setup_logging_legacy() is deprecated; migrate to setup_logging() and get_logger().",
        DeprecationWarning,
        stacklevel=2,
    )
    return setup_logging(*args, **kwargs)

metadata

metadata

JsonValue = Union[str, int, float, bool, list, dict, None] module-attribute
logger = get_child_logger(__name__) module-attribute
Frontmatter

Handles YAML frontmatter embedding and extraction.

Source code in src/tnh_scholar/metadata/metadata.py, lines 239-293
class Frontmatter:
    """Handles YAML frontmatter embedding and extraction."""
    @staticmethod
    def extract(content: str) -> tuple[Metadata, str]:
        """Extract frontmatter and content from text.

        Args:
            content: Text with optional YAML frontmatter

        Returns:
            Tuple of (metadata object, remaining content)
        """
        pattern = r'^---\s*\n(.*?)\n---\s*\n(.*)$'
        if match := re.match(pattern, content, re.DOTALL):
            try:
                yaml_data = safe_yaml_load(match[1], context="Frontmatter.extract")
                return Metadata(yaml_data or {}), match[2]
            except yaml.YAMLError:
                logger.warning("YAML Error in Frontmatter extraction.")
                return Metadata(), content
        return Metadata(), content

    @classmethod
    def extract_from_file(cls, file: Path) -> tuple[Metadata, str]:
        text_str = read_str_from_file(file)
        return cls.extract(text_str)

    @classmethod
    def embed(cls, metadata: Metadata, content: str) -> str:
        """Embed metadata as YAML frontmatter.

        Args:
            metadata: Metadata to embed as frontmatter
            content: Content text

        Returns:
            Text with embedded frontmatter
        """

        # Combine with content
        return (
            f"{cls.generate(metadata)}"
            f"{content.strip()}"
        )

    @staticmethod
    def generate(metadata: Metadata) -> str:
        if not metadata:
            return ""

        yaml_str = metadata.to_yaml() 
        return (
            f"---\n"
            f"{yaml_str}---\n\n"
        )
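The delimiter handling used by `Frontmatter.extract` can be exercised standalone with just the split regex (YAML parsing omitted; only the `---` fence matching is shown):

```python
import re

# Same pattern as Frontmatter.extract: a leading '---' fence, the
# frontmatter body, a closing '---' fence, then the remaining content.
pattern = r'^---\s*\n(.*?)\n---\s*\n(.*)$'

text = "---\ntitle: Demo\n---\n\nBody text here."
match = re.match(pattern, text, re.DOTALL)
frontmatter, body = match[1], match[2]
```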
embed(metadata, content) classmethod

Embed metadata as YAML frontmatter.

Parameters:

Name Type Description Default
metadata Metadata

Metadata to embed as frontmatter

required
content str

Content text

required

Returns:

Type Description
str

Text with embedded frontmatter

Source code in src/tnh_scholar/metadata/metadata.py, lines 266-282
@classmethod
def embed(cls, metadata: Metadata, content: str) -> str:
    """Embed metadata as YAML frontmatter.

    Args:
        metadata: Metadata to embed as frontmatter
        content: Content text

    Returns:
        Text with embedded frontmatter
    """

    # Combine with content
    return (
        f"{cls.generate(metadata)}"
        f"{content.strip()}"
    )
extract(content) staticmethod

Extract frontmatter and content from text.

Parameters:

Name Type Description Default
content str

Text with optional YAML frontmatter

required

Returns:

Type Description
tuple[Metadata, str]

Tuple of (metadata object, remaining content)

Source code in src/tnh_scholar/metadata/metadata.py, lines 241-259
@staticmethod
def extract(content: str) -> tuple[Metadata, str]:
    """Extract frontmatter and content from text.

    Args:
        content: Text with optional YAML frontmatter

    Returns:
        Tuple of (metadata object, remaining content)
    """
    pattern = r'^---\s*\n(.*?)\n---\s*\n(.*)$'
    if match := re.match(pattern, content, re.DOTALL):
        try:
            yaml_data = safe_yaml_load(match[1], context="Frontmatter.extract")
            return Metadata(yaml_data or {}), match[2]
        except yaml.YAMLError:
            logger.warning("YAML Error in Frontmatter extraction.")
            return Metadata(), content
    return Metadata(), content
extract_from_file(file) classmethod
Source code in src/tnh_scholar/metadata/metadata.py, lines 261-264
@classmethod
def extract_from_file(cls, file: Path) -> tuple[Metadata, str]:
    text_str = read_str_from_file(file)
    return cls.extract(text_str)
generate(metadata) staticmethod
Source code in src/tnh_scholar/metadata/metadata.py, lines 284-293
@staticmethod
def generate(metadata: Metadata) -> str:
    if not metadata:
        return ""

    yaml_str = metadata.to_yaml() 
    return (
        f"---\n"
        f"{yaml_str}---\n\n"
    )
Metadata

Bases: MutableMapping

Flexible metadata container that behaves like a dict while ensuring JSON serializability. Designed for AI processing pipelines where schema flexibility is prioritized over structure.

Source code in src/tnh_scholar/metadata/metadata.py, lines 38-217
class Metadata(MutableMapping):
    """
    Flexible metadata container that behaves like a dict while ensuring
    JSON serializability. Designed for AI processing pipelines where schema
    flexibility is prioritized over structure.
    """
    # Type processors at class level
    _type_processors = {
        Path: lambda p: path_as_str(p),
        datetime: lambda d: d.isoformat(),
    }

    def __init__(
        self, 
        data: Optional[Union[Dict[str, Any], 'Metadata']] = None
        ) -> None:
        self._data: Dict[str, JsonValue] = {}
        if data is not None:
            raw_data = data._data if isinstance(data, Metadata) else data
            processed_data = {
                k: self._process_value(v) for k, v in raw_data.items()
            }
            self.update(processed_data)

    def _process_value(self, value: Any) -> JsonValue:
        """Convert input values to JSON-serializable format."""
        if isinstance(value, tuple(self._type_processors.keys())):
            for type_, processor in self._type_processors.items():
                if isinstance(value, type_):
                    return processor(value)
        if not isinstance(value, (str, int, float, bool, list, dict, type(None))):
            raise ValueError(
                f"Value {value} of type {type(value)} has no conversion to JsonValue.")
        return value

    # Core dict interface
    def __getitem__(self, key: str) -> JsonValue:
        return self._data[key]

    def __setitem__(self, key: str, value: Any) -> None:
        """Process and set value, ensuring JSON serializability."""
        self._data[key] = self._process_value(value)

    def __delitem__(self, key: str) -> None:
        del self._data[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __str__(self) -> str:
        return self.to_yaml()

    # Dict union operations (|, |=)
    def __or__(self, other: Union[Mapping[str, JsonValue], 'Metadata']) -> 'Metadata':
        if isinstance(other, (Metadata, Mapping)):
            other_dict = other._data if isinstance(other, Metadata) else other
            return Metadata(self._data | other_dict) # type: ignore
        return NotImplemented

    def __ror__(self, other: Mapping[str, JsonValue]) -> 'Metadata':
        if isinstance(other, Mapping):
            return Metadata(other | self._data) # type: ignore
        return NotImplemented

    def __ior__(self, other: Union[Mapping[str, JsonValue], 'Metadata']) -> 'Metadata':
        if isinstance(other, (Metadata, Mapping)):
            self._data |= (other._data if isinstance(other, Metadata) else other)
            return self
        return NotImplemented

    def __repr__(self) -> str:
        return f"Metadata({self._data})"

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        source_type: Any,
        handler: Callable[[Any], core_schema.CoreSchema],
    ) -> core_schema.CoreSchema:
        """Defines the Pydantic core schema for the `Metadata` class.

        This method allows Pydantic to validate `Metadata` objects as dictionaries.
        It handles both direct `Metadata` instances and dictionaries during validation,
        providing flexibility for data input.

        Args:
            source_type: The source type being validated.
            handler: A callable to handle schema generation for other types.

        Returns:
            A Pydantic core schema that validates either a Metadata instance
            (by converting it to a dictionary) or a standard dictionary.
        """
        return core_schema.union_schema(
            choices=[
                # Handle Metadata instances with serialization
                core_schema.is_instance_schema(
                    cls,
                    serialization=core_schema.plain_serializer_function_ser_schema(
                        lambda x: x.to_dict()  # Use our to_dict method
                    )
                ),
                # Handle dictionary input
                handler(dict),
            ],
        )

    # JSON serialization
    def to_dict(self) -> Dict[str, JsonValue]:
        """Convert to plain dict for JSON serialization."""
        return self._data.copy()

    @classmethod
    def from_dict(cls, data: Dict[str, JsonValue]) -> 'Metadata':
        """Create from a plain dict."""
        return cls(data)

    def copy(self) -> 'Metadata':
        """Create a deep copy of the metadata object."""
        return Metadata(deepcopy(self._data))

    @classmethod
    def from_fields(cls, data: dict, fields: list[str]) -> "Metadata":
        """Create a Metadata object by extracting specified fields from a dictionary.

        Args:
            data: Source dictionary
            fields: List of field names to extract

        Returns:
            New Metadata instance with only specified fields
        """
        filtered = {k: data.get(k) for k in fields if k in data}
        return cls(filtered)

    @classmethod
    def from_yaml(cls, yaml_str: str) -> 'Metadata':
        """Create Metadata instance from YAML string.

        Args:
            yaml_str: YAML formatted string

        Returns:
            New Metadata instance

        Raises:
            yaml.YAMLError: If YAML parsing fails
        """
        if not yaml_str.strip():
            return cls()

        data = safe_yaml_load(yaml_str, context="Metadata.from_yaml()")
        return cls(data) if isinstance(data, dict) else cls()

    def text_embed(self, content: str):
        return Frontmatter.embed(self, content)

    def add_process_info(self, process_metadata: 'ProcessMetadata') -> None:
        """Add process metadata to history."""
        history = self.get(TNH_METADATA_PROCESS_FIELD, [])
        if not isinstance(history, list):
            history = []
        history.append(process_metadata.to_dict())  # Store as dict for serialization
        self[TNH_METADATA_PROCESS_FIELD] = history

    @property
    def process_history(self) -> List[Dict[str, Any]]:
        """Access process history with proper typing."""
        return self.get(TNH_METADATA_PROCESS_FIELD, [])

    def to_yaml(self) -> str:
        """Return metadata as YAML formatted string"""
        return yaml.dump(
            self._data,
            default_flow_style=False,
            allow_unicode=True
        )
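The JSON-coercion step at the heart of the class can be sketched in isolation: `Path` and `datetime` values are converted to strings via the type-processor table, and anything without a JSON-friendly representation is rejected (function name here is illustrative):

```python
from datetime import datetime
from pathlib import Path

# Mirrors Metadata._type_processors: known rich types get a converter.
_type_processors = {
    Path: lambda p: str(p),
    datetime: lambda d: d.isoformat(),
}

def process_value(value):
    for type_, processor in _type_processors.items():
        if isinstance(value, type_):
            return processor(value)
    # Anything else must already be JSON-serializable.
    if not isinstance(value, (str, int, float, bool, list, dict, type(None))):
        raise ValueError(f"{type(value)} has no conversion to JsonValue")
    return value

converted = process_value(Path("notes") / "talk.md")
```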
process_history property

Access process history with proper typing.

__delitem__(key)
Source code in src/tnh_scholar/metadata/metadata.py, lines 81-82
def __delitem__(self, key: str) -> None:
    del self._data[key]
__get_pydantic_core_schema__(source_type, handler) classmethod

Defines the Pydantic core schema for the Metadata class.

This method allows Pydantic to validate Metadata objects as dictionaries. It handles both direct Metadata instances and dictionaries during validation, providing flexibility for data input.

Parameters:

Name Type Description Default
source_type Any

The source type being validated.

required
handler Callable[[Any], CoreSchema]

A callable to handle schema generation for other types.

required

Returns:

Type Description
CoreSchema

A Pydantic core schema that validates either a Metadata instance

CoreSchema

(by converting it to a dictionary) or a standard dictionary.

Source code in src/tnh_scholar/metadata/metadata.py, lines 114-146
@classmethod
def __get_pydantic_core_schema__(
    cls,
    source_type: Any,
    handler: Callable[[Any], core_schema.CoreSchema],
) -> core_schema.CoreSchema:
    """Defines the Pydantic core schema for the `Metadata` class.

    This method allows Pydantic to validate `Metadata` objects as dictionaries.
    It handles both direct `Metadata` instances and dictionaries during validation,
    providing flexibility for data input.

    Args:
        source_type: The source type being validated.
        handler: A callable to handle schema generation for other types.

    Returns:
        A Pydantic core schema that validates either a Metadata instance
        (by converting it to a dictionary) or a standard dictionary.
    """
    return core_schema.union_schema(
        choices=[
            # Handle Metadata instances with serialization
            core_schema.is_instance_schema(
                cls,
                serialization=core_schema.plain_serializer_function_ser_schema(
                    lambda x: x.to_dict()  # Use our to_dict method
                )
            ),
            # Handle dictionary input
            handler(dict),
        ],
    )
__getitem__(key)
Source code in src/tnh_scholar/metadata/metadata.py, lines 74-75
def __getitem__(self, key: str) -> JsonValue:
    return self._data[key]
__init__(data=None)
Source code in src/tnh_scholar/metadata/metadata.py, lines 50-60
def __init__(
    self, 
    data: Optional[Union[Dict[str, Any], 'Metadata']] = None
    ) -> None:
    self._data: Dict[str, JsonValue] = {}
    if data is not None:
        raw_data = data._data if isinstance(data, Metadata) else data
        processed_data = {
            k: self._process_value(v) for k, v in raw_data.items()
        }
        self.update(processed_data)
__ior__(other)
Source code in src/tnh_scholar/metadata/metadata.py, lines 105-109
def __ior__(self, other: Union[Mapping[str, JsonValue], 'Metadata']) -> 'Metadata':
    if isinstance(other, (Metadata, Mapping)):
        self._data |= (other._data if isinstance(other, Metadata) else other)
        return self
    return NotImplemented
__iter__()
Source code in src/tnh_scholar/metadata/metadata.py, lines 84-85
def __iter__(self) -> Iterator[str]:
    return iter(self._data)
__len__()
Source code in src/tnh_scholar/metadata/metadata.py, lines 87-88
def __len__(self) -> int:
    return len(self._data)
__or__(other)
Source code in src/tnh_scholar/metadata/metadata.py, lines 94-98
def __or__(self, other: Union[Mapping[str, JsonValue], 'Metadata']) -> 'Metadata':
    if isinstance(other, (Metadata, Mapping)):
        other_dict = other._data if isinstance(other, Metadata) else other
        return Metadata(self._data | other_dict) # type: ignore
    return NotImplemented
__repr__()
Source code in src/tnh_scholar/metadata/metadata.py, lines 111-112
def __repr__(self) -> str:
    return f"Metadata({self._data})"
__ror__(other)
Source code in src/tnh_scholar/metadata/metadata.py, lines 100-103
def __ror__(self, other: Mapping[str, JsonValue]) -> 'Metadata':
    if isinstance(other, Mapping):
        return Metadata(other | self._data) # type: ignore
    return NotImplemented
__setitem__(key, value)

Process and set value, ensuring JSON serializability.

Source code in src/tnh_scholar/metadata/metadata.py, lines 77-79
def __setitem__(self, key: str, value: Any) -> None:
    """Process and set value, ensuring JSON serializability."""
    self._data[key] = self._process_value(value)
__str__()
Source code in src/tnh_scholar/metadata/metadata.py, lines 90-91
def __str__(self) -> str:
    return self.to_yaml()
add_process_info(process_metadata)

Add process metadata to history.

Source code in src/tnh_scholar/metadata/metadata.py, lines 198-204
def add_process_info(self, process_metadata: 'ProcessMetadata') -> None:
    """Add process metadata to history."""
    history = self.get(TNH_METADATA_PROCESS_FIELD, [])
    if not isinstance(history, list):
        history = []
    history.append(process_metadata.to_dict())  # Store as dict for serialization
    self[TNH_METADATA_PROCESS_FIELD] = history
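The method follows a read-modify-write pattern: fetch the existing history list (falling back to an empty list if the field is absent or not a list), append the new record, and store the list back. The same pattern sketched with a plain dict, using a hypothetical `"process_history"` key in place of `TNH_METADATA_PROCESS_FIELD`:

```python
def append_history(meta: dict, record: dict, field: str = "process_history") -> None:
    """Append record to meta[field], tolerating a missing or non-list value."""
    history = meta.get(field, [])
    if not isinstance(history, list):
        history = []  # discard a malformed history rather than failing
    history.append(record)
    meta[field] = history

meta = {"process_history": "corrupted"}  # non-list value gets replaced
append_history(meta, {"step": "ocr"})
append_history(meta, {"step": "translate"})
```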
copy()

Create a deep copy of the metadata object.

Source code in src/tnh_scholar/metadata/metadata.py
def copy(self) -> 'Metadata':
    """Create a deep copy of the metadata object."""
    return Metadata(deepcopy(self._data))
from_dict(data) classmethod

Create from a plain dict.

Source code in src/tnh_scholar/metadata/metadata.py
@classmethod
def from_dict(cls, data: Dict[str, JsonValue]) -> 'Metadata':
    """Create from a plain dict."""
    return cls(data)
from_fields(data, fields) classmethod

Create a Metadata object by extracting specified fields from a dictionary.

Parameters:

- data (dict): Source dictionary. Required.
- fields (list[str]): List of field names to extract. Required.

Returns:

- Metadata: New Metadata instance with only the specified fields.

Source code in src/tnh_scholar/metadata/metadata.py
@classmethod
def from_fields(cls, data: dict, fields: list[str]) -> "Metadata":
    """Create a Metadata object by extracting specified fields from a dictionary.

    Args:
        data: Source dictionary
        fields: List of field names to extract

    Returns:
        New Metadata instance with only specified fields
    """
    filtered = {k: data.get(k) for k in fields if k in data}
    return cls(filtered)
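The extraction is a single dict comprehension: a key listed in `fields` is copied only when it is present in `data`, so missing fields are silently dropped rather than filled with None. The core expression in isolation (sample data is illustrative):

```python
data = {"title": "Peace Is Every Step", "year": 1991, "isbn": None}
fields = ["title", "isbn", "publisher"]  # "publisher" is absent from data

# Keys are kept only if present; an explicit None value still survives.
filtered = {k: data.get(k) for k in fields if k in data}
```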
from_yaml(yaml_str) classmethod

Create Metadata instance from YAML string.

Parameters:

- yaml_str (str): YAML formatted string. Required.

Returns:

- Metadata: New Metadata instance.

Raises:

- yaml.YAMLError: If YAML parsing fails.

Source code in src/tnh_scholar/metadata/metadata.py
@classmethod
def from_yaml(cls, yaml_str: str) -> 'Metadata':
    """Create Metadata instance from YAML string.

    Args:
        yaml_str: YAML formatted string

    Returns:
        New Metadata instance

    Raises:
        yaml.YAMLError: If YAML parsing fails
    """
    if not yaml_str.strip():
        return cls()

    data = safe_yaml_load(yaml_str, context="Metadata.from_yaml()")
    return cls(data) if isinstance(data, dict) else cls()
text_embed(content)
Source code in src/tnh_scholar/metadata/metadata.py
def text_embed(self, content: str):
    return Frontmatter.embed(self, content)
to_dict()

Convert to plain dict for JSON serialization.

Source code in src/tnh_scholar/metadata/metadata.py
def to_dict(self) -> Dict[str, JsonValue]:
    """Convert to plain dict for JSON serialization."""
    return self._data.copy()
to_yaml()

Return metadata as YAML formatted string

Source code in src/tnh_scholar/metadata/metadata.py
def to_yaml(self) -> str:
    """Return metadata as YAML formatted string"""
    return yaml.dump(
        self._data,
        default_flow_style=False,
        allow_unicode=True
    )
ProcessMetadata

Bases: Metadata

Records information about a specific processing operation.

Source code in src/tnh_scholar/metadata/metadata.py
class ProcessMetadata(Metadata):
    """Records information about a specific processing operation."""
    def __init__(
        self,
        step: str,
        processor: str, 
        tool: Optional[str] = None,
        **additional_params
    ):
        # Initialize base Metadata with our process data structure
        super().__init__({
            "step": step,
            "timestamp": datetime.now(),
            "processor": processor,
            "tool": tool,
        })

        # Add any additional parameters at top level
        self.update(additional_params)
__init__(step, processor, tool=None, **additional_params)
Source code in src/tnh_scholar/metadata/metadata.py
def __init__(
    self,
    step: str,
    processor: str, 
    tool: Optional[str] = None,
    **additional_params
):
    # Initialize base Metadata with our process data structure
    super().__init__({
        "step": step,
        "timestamp": datetime.now(),
        "processor": processor,
        "tool": tool,
    })

    # Add any additional parameters at top level
    self.update(additional_params)
safe_yaml_load(yaml_str, *, context='unknown')
Source code in src/tnh_scholar/metadata/metadata.py
def safe_yaml_load(yaml_str: str, *, context: str = "unknown") -> dict:
    try:
        data = yaml.safe_load(yaml_str)
        if not isinstance(data, dict):
            logger.warning(
                "YAML in [%s] is not a dict. Returning empty metadata.", context
                )
            return {}
        return data
    except ScannerError as e:
        snippet = yaml_str.replace("\n", "\\n")
        logger.error("YAML ScannerError in [%s]: %s\nSnippet:\n%s", context, e, snippet)
    except yaml.YAMLError as e:
        logger.error("General YAML error in [%s]: %s", context, e)
    return {}
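The function's defensive shape — any parse failure or non-dict result collapses to an empty dict, so callers never branch on parse errors — is independent of YAML. A dependency-free sketch of the same guard pattern using the stdlib `json` module in place of PyYAML:

```python
import json
import logging

logger = logging.getLogger("metadata_sketch")

def safe_json_load(text: str, *, context: str = "unknown") -> dict:
    """Same defensive shape as safe_yaml_load: failures and non-dict
    results both collapse to an empty dict."""
    try:
        data = json.loads(text)
        if not isinstance(data, dict):
            logger.warning("JSON in [%s] is not a dict.", context)
            return {}
        return data
    except json.JSONDecodeError as e:
        logger.error("JSON error in [%s]: %s", context, e)
        return {}

ok = safe_json_load('{"step": "ocr"}')
not_a_dict = safe_json_load('[1, 2, 3]')
broken = safe_json_load('{oops')
```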

ocr_processing

DEFAULT_ANNOTATION_FONT_PATH = Path('/System/Library/Fonts/Supplemental/Arial.ttf') module-attribute

DEFAULT_ANNOTATION_FONT_SIZE = 12 module-attribute

DEFAULT_ANNOTATION_LANGUAGE_HINTS = ['vi'] module-attribute

DEFAULT_ANNOTATION_METHOD = 'DOCUMENT_TEXT_DETECTION' module-attribute

DEFAULT_ANNOTATION_OFFSET = 2 module-attribute

logger = logging.getLogger('ocr_processing') module-attribute

PDFParseWarning

Bases: Warning

Custom warning class for PDF parsing issues. Encapsulates minimal logic for displaying warnings with a custom format.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
class PDFParseWarning(Warning):
    """
    Custom warning class for PDF parsing issues.
    Encapsulates minimal logic for displaying warnings with a custom format.
    """

    @staticmethod
    def warn(message: str):
        """
        Display a warning message with custom formatting.

        Parameters:
            message (str): The warning message to display.
        """
        formatted_message = f"\033[93mPDFParseWarning: {message}\033[0m"
        print(formatted_message)  # Simply prints the warning
warn(message) staticmethod

Display a warning message with custom formatting.

Parameters:

- message (str): The warning message to display. Required.
Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
@staticmethod
def warn(message: str):
    """
    Display a warning message with custom formatting.

    Parameters:
        message (str): The warning message to display.
    """
    formatted_message = f"\033[93mPDFParseWarning: {message}\033[0m"
    print(formatted_message)  # Simply prints the warning
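The "custom format" is just an ANSI escape wrapper: `\033[93m` switches the terminal to bright yellow and `\033[0m` resets it. The formatting in isolation:

```python
YELLOW = "\033[93m"  # bright-yellow foreground (ANSI SGR code 93)
RESET = "\033[0m"    # restore default terminal attributes

def format_warning(message: str) -> str:
    """Wrap a message the way PDFParseWarning.warn does before printing."""
    return f"{YELLOW}PDFParseWarning: {message}{RESET}"

line = format_warning("page 3 is empty")
```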

annotate_image_with_text(image, text_annotations, annotation_font_path, font_size=12)

Annotates a PIL image with bounding boxes and text descriptions from OCR results.

Parameters:

- image (Image.Image): The input PIL image to annotate. Required.
- text_annotations (List[EntityAnnotation]): OCR results containing bounding boxes and text. Required.
- annotation_font_path (str): Path to the font file for text annotations. Required.
- font_size (int): Font size for text annotations. Default: 12.

Returns:

- Image.Image: The annotated PIL image.

Raises:

- ValueError: If the input image is None.
- IOError: If the font file cannot be loaded.
- Exception: For any other unexpected errors.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def annotate_image_with_text(
    image: Image.Image,
    text_annotations: List[EntityAnnotation],
    annotation_font_path: str,
    font_size: int = 12,
) -> Image.Image:
    """
    Annotates a PIL image with bounding boxes and text descriptions from OCR results.

    Parameters:
        image (Image.Image): The input PIL image to annotate.
        text_annotations (List[EntityAnnotation]): OCR results containing bounding boxes and text.
        annotation_font_path (str): Path to the font file for text annotations.
        font_size (int): Font size for text annotations.

    Returns:
        Image.Image: The annotated PIL image.

    Raises:
        ValueError: If the input image is None.
        IOError: If the font file cannot be loaded.
        Exception: For any other unexpected errors.
    """
    if image is None:
        raise ValueError("The input image is None.")

    try:
        font = ImageFont.truetype(annotation_font_path, font_size)
    except IOError as e:
        raise IOError(f"Failed to load the font from '{annotation_font_path}': {e}")

    draw = ImageDraw.Draw(image)

    try:
        for i, text_obj in enumerate(text_annotations):
            vertices = [
                (vertex.x, vertex.y) for vertex in text_obj.bounding_poly.vertices
            ]
            if (
                len(vertices) == 4
            ):  # Ensure there are exactly 4 vertices for a rectangle
                # Draw the bounding box
                draw.polygon(vertices, outline="red", width=2)

                # Skip the first bounding box (whole text region)
                if i > 0:
                    # Offset the text position slightly for clarity
                    text_position = (vertices[0][0] + 2, vertices[0][1] + 2)
                    draw.text(
                        text_position, text_obj.description, fill="red", font=font
                    )

    except AttributeError as e:
        raise ValueError(f"Invalid text annotation structure: {e}")
    except Exception as e:
        raise Exception(f"An error occurred during image annotation: {e}")

    return image

build_processed_pdf(pdf_path, client, preprocessor=None, annotation_font_path=DEFAULT_ANNOTATION_FONT_PATH)

Processes a PDF document, extracting text, word locations, annotated images, and unannotated images.

Parameters:

- pdf_path (Path): Path to the PDF file. Required.
- client (vision.ImageAnnotatorClient): Google Vision API client for text detection. Required.
- preprocessor (Callable, optional): Preprocessing function applied to each page image. Default: None.
- annotation_font_path (Path): Path to the font file for annotations. Default: DEFAULT_ANNOTATION_FONT_PATH.

Returns:

- Tuple[List[str], List[List[vision.EntityAnnotation]], List[Image.Image], List[Image.Image]]:
    - List of extracted full-page texts (one entry per page).
    - List of word locations (list of vision.EntityAnnotation objects for each page).
    - List of annotated images (with bounding boxes and text annotations).
    - List of unannotated images (raw page images).

Raises:

- FileNotFoundError: If the specified PDF file does not exist.
- ValueError: If the PDF file is invalid or contains no pages.
- Exception: For any unexpected errors during processing.

Example:

    from pathlib import Path
    from google.cloud import vision

    pdf_path = Path("/path/to/example.pdf")
    font_path = Path("/path/to/fonts/Arial.ttf")
    client = vision.ImageAnnotatorClient()
    try:
        text_pages, word_locations_list, annotated_images, unannotated_images = build_processed_pdf(
            pdf_path, client, annotation_font_path=font_path
        )
        print(f"Processed {len(text_pages)} pages successfully!")
    except Exception as e:
        print(f"Error processing PDF: {e}")

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def build_processed_pdf(
    pdf_path: Path,
    client: vision.ImageAnnotatorClient,
    preprocessor: Callable = None,
    annotation_font_path: Path = DEFAULT_ANNOTATION_FONT_PATH,
) -> Tuple[
    List[str], List[List[vision.EntityAnnotation]], List[Image.Image], List[Image.Image]
]:
    """
    Processes a PDF document, extracting text, word locations, annotated images, and unannotated images.

    Parameters:
        pdf_path (Path): Path to the PDF file.
        client (vision.ImageAnnotatorClient): Google Vision API client for text detection.
        preprocessor (Callable, optional): Preprocessing function applied to each page image.
        annotation_font_path (Path): Path to the font file for annotations.

    Returns:
        Tuple[List[str], List[List[vision.EntityAnnotation]], List[Image.Image], List[Image.Image]]:
            - List of extracted full-page texts (one entry per page).
            - List of word locations (list of `vision.EntityAnnotation` objects for each page).
            - List of annotated images (with bounding boxes and text annotations).
            - List of unannotated images (raw page images).

    Raises:
        FileNotFoundError: If the specified PDF file does not exist.
        ValueError: If the PDF file is invalid or contains no pages.
        Exception: For any unexpected errors during processing.

    Example:
        >>> from pathlib import Path
        >>> from google.cloud import vision
        >>> pdf_path = Path("/path/to/example.pdf")
        >>> font_path = Path("/path/to/fonts/Arial.ttf")
        >>> client = vision.ImageAnnotatorClient()
        >>> try:
        >>>     text_pages, word_locations_list, annotated_images, unannotated_images = build_processed_pdf(
        >>>         pdf_path, client, annotation_font_path=font_path
        >>>     )
        >>>     print(f"Processed {len(text_pages)} pages successfully!")
        >>> except Exception as e:
        >>>     print(f"Error processing PDF: {e}")
    """
    try:
        doc = load_pdf_pages(pdf_path)
    except FileNotFoundError as fnf_error:
        raise FileNotFoundError(f"Error loading PDF: {fnf_error}")
    except ValueError as ve:
        raise ValueError(f"Invalid PDF file: {ve}")
    except Exception as e:
        raise Exception(f"An unexpected error occurred while loading the PDF: {e}")

    if doc.page_count == 0:
        raise ValueError(f"The PDF file '{pdf_path}' contains no pages.")

    logger.info(f"Processing file with {doc.page_count} pages:\n\t{pdf_path}")

    text_pages = []
    word_locations_list = []
    annotated_images = []
    unannotated_images = []
    first_page_dimensions = None

    for page_num in range(doc.page_count):
        logger.info(f"Processing page {page_num + 1}/{doc.page_count}...")

        try:
            page = doc.load_page(page_num)
            (
                full_page_text,
                word_locations,
                annotated_image,
                unannotated_image,
                page_dimensions,
            ) = process_page(page, client, annotation_font_path, preprocessor)

            if full_page_text:  # this is not an empty page

                if page_num == 0:  # save first page info
                    first_page_dimensions = page_dimensions
                elif (
                    page_dimensions != first_page_dimensions
                ):  # verify page dimensions are consistent
                    PDFParseWarning.warn(
                        f"Page {page_num + 1} has different dimensions "
                        f"({page_dimensions}) than the first page ({first_page_dimensions})."
                    )

                text_pages.append(full_page_text)
                word_locations_list.append(word_locations)
                annotated_images.append(annotated_image)
                unannotated_images.append(unannotated_image)
            else:
                PDFParseWarning.warn(
                    f"Page {page_num + 1} empty; page skipped.\n"
                    f"  (Note that total document length will be reduced.)"
                )

        except ValueError as ve:
            print(f"ValueError on page {page_num + 1}: {ve}")
        except OSError as oe:
            print(f"OSError on page {page_num + 1}: {oe}")
        except Exception as e:
            print(f"Unexpected error on page {page_num + 1}: {e}")

    print(f"page dimensions: {page_dimensions}")
    return text_pages, word_locations_list, annotated_images, unannotated_images

deserialize_entity_annotations_from_json(data)

Deserializes JSON data into a nested list of EntityAnnotation objects.

Parameters:

- data (str): The JSON string containing serialized annotations. Required.

Returns:

- List[List[EntityAnnotation]]: The reconstructed nested list of EntityAnnotation objects.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def deserialize_entity_annotations_from_json(data: str) -> List[List[EntityAnnotation]]:
    """
    Deserializes JSON data into a nested list of EntityAnnotation objects.

    Parameters:
        data (str): The JSON string containing serialized annotations.

    Returns:
        List[List[EntityAnnotation]]: The reconstructed nested list of EntityAnnotation objects.
    """
    serialized_data = json.loads(data)
    deserialized_data = []

    for serialized_page in serialized_data:
        page_annotations = [
            EntityAnnotation.deserialize(base64.b64decode(serialized_annotation))
            for serialized_annotation in serialized_page
        ]
        deserialized_data.append(page_annotations)

    return deserialized_data
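Because protobuf messages serialize to raw bytes and JSON cannot carry bytes, each annotation is base64-encoded before being placed in the nested JSON lists, and decoded on the way back. The round trip sketched with arbitrary byte payloads standing in for `EntityAnnotation.serialize()` output:

```python
import base64
import json

# Stand-ins for EntityAnnotation.serialize() output on two pages.
pages = [[b"\x08\x01annot-a", b"\x08\x02annot-b"], [b"\x08\x03annot-c"]]

# Serialize: bytes -> base64 text -> JSON.
payload = json.dumps(
    [[base64.b64encode(a).decode("ascii") for a in page] for page in pages]
)

# Deserialize: JSON -> base64 text -> original bytes.
restored = [
    [base64.b64decode(s) for s in page] for page in json.loads(payload)
]
```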

extract_image_from_page(page)

Extracts the first image from the given PDF page and returns it as a PIL Image.

Parameters:

- page (fitz.Page): The PDF page object. Required.

Returns:

- Image.Image: The first image on the page as a Pillow Image object.

Raises:

- ValueError: If no images are found on the page or the image data is incomplete.
- Exception: For unexpected errors during image extraction.

Example:

    import fitz
    from PIL import Image

    doc = fitz.open("/path/to/document.pdf")
    page = doc.load_page(0)  # Load the first page
    try:
        image = extract_image_from_page(page)
        image.show()  # Display the image
    except Exception as e:
        print(f"Error extracting image: {e}")

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def extract_image_from_page(page: fitz.Page) -> Image.Image:
    """
    Extracts the first image from the given PDF page and returns it as a PIL Image.

    Parameters:
        page (fitz.Page): The PDF page object.

    Returns:
        Image.Image: The first image on the page as a Pillow Image object.

    Raises:
        ValueError: If no images are found on the page or the image data is incomplete.
        Exception: For unexpected errors during image extraction.

    Example:
        >>> import fitz
        >>> from PIL import Image
        >>> doc = fitz.open("/path/to/document.pdf")
        >>> page = doc.load_page(0)  # Load the first page
        >>> try:
        >>>     image = extract_image_from_page(page)
        >>>     image.show()  # Display the image
        >>> except Exception as e:
        >>>     print(f"Error extracting image: {e}")
    """
    try:
        # Get images from the page
        images = page.get_images(full=True)
        if not images:
            raise ValueError("No images found on the page.")

        # Extract the first image reference
        xref = images[0][0]  # Get the first image's xref
        base_image = page.parent.extract_image(xref)

        # Validate the extracted image data
        if (
            "image" not in base_image
            or "width" not in base_image
            or "height" not in base_image
        ):
            raise ValueError("The extracted image data is incomplete.")

        # Convert the raw image bytes into a Pillow image
        image_bytes = base_image["image"]
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        return pil_image

    except ValueError as ve:
        raise ve  # Re-raise for calling functions to handle
    except Exception as e:
        raise Exception(f"An unexpected error occurred during image extraction: {e}")

get_page_dimensions(page)

Extracts the width and height of a single PDF page in both inches and pixels.

Parameters:

- page (fitz.Page): A single PDF page object from PyMuPDF. Required.

Returns:

- dict: A dictionary containing the width and height of the page in inches and pixels.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def get_page_dimensions(page: fitz.Page) -> dict:
    """
    Extracts the width and height of a single PDF page in both inches and pixels.

    Args:
        page (fitz.Page): A single PDF page object from PyMuPDF.

    Returns:
        dict: A dictionary containing the width and height of the page in inches and pixels.
    """
    # Get page dimensions in points and convert to inches
    page_width_pts, page_height_pts = page.rect.width, page.rect.height
    page_width_in = page_width_pts / 72  # Convert points to inches
    page_height_in = page_height_pts / 72

    # Extract the first image on the page (if any) to get pixel dimensions
    images = page.get_images(full=True)
    if images:
        xref = images[0][0]
        base_image = page.parent.extract_image(xref)
        width_px = base_image["width"]
        height_px = base_image["height"]
    else:
        width_px, height_px = None, None  # No image found on the page

    # Return dimensions
    return {
        "width_in": page_width_in,
        "height_in": page_height_in,
        "width_px": width_px,
        "height_px": height_px,
    }
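PDF geometry is expressed in points, defined as exactly 72 points per inch, so the conversion is a single division. For a US Letter page (612 x 792 points):

```python
POINTS_PER_INCH = 72  # PDF user-space unit definition

width_pts, height_pts = 612.0, 792.0  # US Letter page in points
width_in = width_pts / POINTS_PER_INCH
height_in = height_pts / POINTS_PER_INCH
```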

load_pdf_pages(pdf_path)

Opens the PDF document and returns the fitz Document object.

Parameters:

- pdf_path (Path): The path to the PDF file. Required.

Returns:

- fitz.Document: The loaded PDF document.

Raises:

- FileNotFoundError: If the specified file does not exist.
- ValueError: If the file is not a valid PDF document.
- Exception: For any unexpected error.

Example:

    from pathlib import Path

    pdf_path = Path("/path/to/example.pdf")
    try:
        pdf_doc = load_pdf_pages(pdf_path)
        print(f"PDF contains {pdf_doc.page_count} pages.")
    except Exception as e:
        print(f"Error loading PDF: {e}")

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def load_pdf_pages(pdf_path: Path) -> fitz.Document:
    """
    Opens the PDF document and returns the fitz Document object.

    Parameters:
        pdf_path (Path): The path to the PDF file.

    Returns:
        fitz.Document: The loaded PDF document.

    Raises:
        FileNotFoundError: If the specified file does not exist.
        ValueError: If the file is not a valid PDF document.
        Exception: For any unexpected error.

    Example:
        >>> from pathlib import Path
        >>> pdf_path = Path("/path/to/example.pdf")
        >>> try:
        >>>     pdf_doc = load_pdf_pages(pdf_path)
        >>>     print(f"PDF contains {pdf_doc.page_count} pages.")
        >>> except Exception as e:
        >>>     print(f"Error loading PDF: {e}")
    """
    if not pdf_path.exists():
        raise FileNotFoundError(f"The file '{pdf_path}' does not exist.")

    if not pdf_path.suffix.lower() == ".pdf":
        raise ValueError(
            f"The file '{pdf_path}' is not a valid PDF document (expected '.pdf')."
        )

    try:
        return fitz.open(str(pdf_path))  # PyMuPDF expects a string path
    except Exception as e:
        raise Exception(f"An unexpected error occurred while opening the PDF: {e}")
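The validation steps before the file is opened are plain pathlib checks and can be exercised without PyMuPDF. A sketch of the same guard clauses (`validate_pdf_path` is a name introduced here for illustration):

```python
from pathlib import Path

def validate_pdf_path(pdf_path: Path) -> None:
    """Raise the same pre-open errors load_pdf_pages raises."""
    if not pdf_path.exists():
        raise FileNotFoundError(f"The file '{pdf_path}' does not exist.")
    if pdf_path.suffix.lower() != ".pdf":
        raise ValueError(f"The file '{pdf_path}' is not a PDF (expected '.pdf').")

# Suffix matching is case-insensitive, so '.PDF' passes the suffix check;
# existence is checked first, so a missing file raises FileNotFoundError.
try:
    validate_pdf_path(Path("missing_file.PDF"))
    outcome = "ok"
except FileNotFoundError:
    outcome = "missing"
except ValueError:
    outcome = "bad-suffix"
```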

load_processed_PDF_data(base_path)

Loads processed PDF data from files using metadata for file references.

Parameters:

- base_path (Path): Directory containing the processed data and its metadata.json. Required.

Returns:

- Tuple[List[str], List[List[EntityAnnotation]], List[Image.Image], List[Image.Image]]:
    - Loaded text pages.
    - Word locations (list of EntityAnnotation objects for each page).
    - Annotated images.
    - Unannotated images.

Raises:

- FileNotFoundError: If any required files are missing.
- ValueError: If the metadata file is incomplete or invalid.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def load_processed_PDF_data(
    base_path: Path,
) -> Tuple[
    List[str], List[List[EntityAnnotation]], List[Image.Image], List[Image.Image]
]:
    """
    Loads processed PDF data from files using metadata for file references.

    Parameters:
        base_path (Path): Directory containing the processed data and its metadata.json.

    Returns:
        Tuple[List[str], List[List[EntityAnnotation]], List[Image.Image], List[Image.Image]]:
            - Loaded text pages.
            - Word locations (list of `EntityAnnotation` objects for each page).
            - Annotated images.
            - Unannotated images.

    Raises:
        FileNotFoundError: If any required files are missing.
        ValueError: If the metadata file is incomplete or invalid.
    """
    metadata_file = base_path / "metadata.json"

    # Load metadata
    try:
        with metadata_file.open("r", encoding="utf-8") as f:
            metadata = json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(f"Metadata file '{metadata_file}' not found.")
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid metadata file format: {e}")

    # Extract file paths from metadata
    text_pages_file = base_path / metadata.get("files", {}).get(
        "text_pages", "text_pages.json"
    )
    word_locations_file = base_path / metadata.get("files", {}).get(
        "word_locations", "word_locations.json"
    )
    images_dir = Path(metadata.get("images_directory", base_path / "images"))

    # Validate file paths
    if not text_pages_file.exists():
        raise FileNotFoundError(f"Text pages file '{text_pages_file}' not found.")
    if not word_locations_file.exists():
        raise FileNotFoundError(
            f"Word locations file '{word_locations_file}' not found."
        )
    if not images_dir.exists() or not images_dir.is_dir():
        raise FileNotFoundError(f"Images directory '{images_dir}' not found.")

    # Load text pages
    with text_pages_file.open("r", encoding="utf-8") as f:
        text_pages = json.load(f)

    # Load word locations
    with word_locations_file.open("r", encoding="utf-8") as f:
        serialized_word_locations = f.read()
        word_locations = deserialize_entity_annotations_from_json(
            serialized_word_locations
        )

    # Load images
    annotated_images = []
    unannotated_images = []
    for file in sorted(
        images_dir.iterdir()
    ):  # Iterate over files in the images directory
        if file.name.startswith("annotated_page_") and file.suffix == ".png":
            annotated_images.append(Image.open(file))
        elif file.name.startswith("unannotated_page_") and file.suffix == ".png":
            unannotated_images.append(Image.open(file))

    # Ensure images were loaded correctly
    if not annotated_images or not unannotated_images:
        raise ValueError(f"No images found in the directory '{images_dir}'.")

    return text_pages, word_locations, annotated_images, unannotated_images
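File references are resolved with chained `dict.get` calls so that a sparse or missing `"files"` section in metadata.json falls back to conventional names. The lookup pattern in isolation (paths and keys here follow the function above; the sample metadata is illustrative):

```python
from pathlib import Path

base_path = Path("/data/processed_book")

# A metadata dict that names only one of the two expected files.
metadata = {"files": {"text_pages": "pages_v2.json"}}

# Chained .get calls: absent "files" yields {}, absent key yields the default.
text_pages_file = base_path / metadata.get("files", {}).get(
    "text_pages", "text_pages.json"
)
word_locations_file = base_path / metadata.get("files", {}).get(
    "word_locations", "word_locations.json"
)
```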

make_image_preprocess_mask(mask_height)

Creates a preprocessing function that masks a specified height at the bottom of the image.

Parameters:

- mask_height (float): The proportion of the image height to mask at the bottom (0.0 to 1.0). Required.

Returns:

- Callable[[Image.Image, int], Image.Image]: A preprocessing function that takes an image and page number as input and returns the processed image.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def make_image_preprocess_mask(
    mask_height: float,
) -> Callable[[Image.Image, int], Image.Image]:
    """
    Creates a preprocessing function that masks a specified height at the bottom of the image.

    Parameters:
        mask_height (float): The proportion of the image height to mask at the bottom (0.0 to 1.0).

    Returns:
        Callable[[Image.Image, int], Image.Image]: A preprocessing function that takes an image
        and page number as input and returns the processed image.
    """

    def pre_process_image(image: Image.Image, page_number: int) -> Image.Image:
        """
        Preprocesses the image by masking the bottom region or performing other preprocessing steps.

        Parameters:
            image (Image.Image): The input image as a Pillow object.
            page_number (int): The page number of the image (useful for conditional preprocessing).

        Returns:
            Image.Image: The preprocessed image.
        """

        if page_number > 0:  # don't apply mask to cover page.
            # Make a copy of the image to avoid modifying the original
            draw = ImageDraw.Draw(image)

            # Get image dimensions
            width, height = image.size

            # Mask the bottom region based on the specified height proportion
            mask_pixels = int(height * mask_height)
            draw.rectangle([(0, height - mask_pixels), (width, height)], fill="black")

        return image

    return pre_process_image
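`make_image_preprocess_mask` is a closure factory: `mask_height` is captured once when the factory runs, while per-call state (the page number) stays a parameter of the returned function. The same shape with the pixel arithmetic but no Pillow dependency (`make_mask_region` is a name introduced here for illustration):

```python
from typing import Callable, Optional, Tuple

def make_mask_region(
    mask_height: float,
) -> Callable[[Tuple[int, int], int], Optional[Tuple[int, int]]]:
    """Return a function computing the top-left corner of the masked band,
    or None for the cover page, mirroring the Pillow version's logic."""
    def mask_region(size: Tuple[int, int], page_number: int):
        if page_number == 0:
            return None  # never mask the cover page
        width, height = size
        mask_pixels = int(height * mask_height)
        return (0, height - mask_pixels)  # top-left of the bottom band
    return mask_region

mask_bottom_tenth = make_mask_region(0.1)  # mask_height captured once
cover = mask_bottom_tenth((1000, 1400), 0)
body = mask_bottom_tenth((1000, 1400), 3)
```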

pil_to_bytes(image, format='PNG')

Converts a Pillow image to raw bytes.

Parameters:

- image (Image.Image): The Pillow image object to convert. Required.
- format (str): The format to save the image as (e.g., "PNG", "JPEG"). Default: "PNG".

Returns:

- bytes: The raw bytes of the image.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def pil_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
    """
    Converts a Pillow image to raw bytes.

    Parameters:
        image (Image.Image): The Pillow image object to convert.
        format (str): The format to save the image as (e.g., "PNG", "JPEG"). Default is "PNG".

    Returns:
        bytes: The raw bytes of the image.
    """
    with io.BytesIO() as output:
        image.save(output, format=format)
        return output.getvalue()
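The helper relies on `io.BytesIO` as an in-memory file: anything with a file-style write interface can target it, and `.getvalue()` returns the accumulated bytes before the buffer closes. The same pattern without Pillow, with a plain writer standing in for `Image.save()` (`payload_to_bytes` is a name introduced here for illustration):

```python
import io

def payload_to_bytes(chunks: list) -> bytes:
    """Write byte chunks into an in-memory buffer and return the raw bytes,
    mirroring how pil_to_bytes captures Image.save() output."""
    with io.BytesIO() as output:
        for chunk in chunks:
            output.write(chunk)
        return output.getvalue()  # must be read before the buffer closes

data = payload_to_bytes([b"\x89PNG", b"\r\n\x1a\n"])  # PNG magic bytes as sample data
```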

process_page(page, client, annotation_font_path, preprocessor=None)

Processes a single PDF page, extracting text, word locations, and annotated images.

Parameters:

- page (fitz.Page): The PDF page object. Required.
- client (vision.ImageAnnotatorClient): Google Vision API client for text detection. Required.
- annotation_font_path (str): Path to the font file for annotations. Required.
- preprocessor (Optional[Callable[[Image.Image, int], Image.Image]]): Per-page preprocessing function. Default: None.

Returns:

- Tuple[str, List[vision.EntityAnnotation], Image.Image, Image.Image, dict]:
  - Full page text (str)
  - Word locations (List of vision.EntityAnnotation)
  - Annotated image (Pillow Image object)
  - Original unprocessed image (Pillow Image object)
  - Page dimensions (dict)

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def process_page(
    page: fitz.Page,
    client: vision.ImageAnnotatorClient,
    annotation_font_path: str,
    preprocessor: Optional[Callable[[Image.Image, int], Image.Image]] = None,
) -> Tuple[str, List[vision.EntityAnnotation], Image.Image, Image.Image, dict]:
    """
    Processes a single PDF page, extracting text, word locations, and annotated images.

    Parameters:
        page (fitz.Page): The PDF page object.
        client (vision.ImageAnnotatorClient): Google Vision API client for text detection.
        annotation_font_path (str): Path to the font file for annotations.
        preprocessor (Optional[Callable[[Image.Image, int], Image.Image]]): Per-page preprocessing function. Defaults to None.

    Returns:
        Tuple[str, List[vision.EntityAnnotation], Image.Image, Image.Image, dict]:
            - Full page text (str)
            - Word locations (List of vision.EntityAnnotation)
            - Annotated image (Pillow Image object)
            - Original unprocessed image (Pillow Image object)
            - Page dimensions (dict)
    """
    # Extract the original image from the PDF page
    original_image = extract_image_from_page(page)

    # Make a copy of the original image for processing
    processed_image = original_image.copy()

    # Apply the preprocessing function (if provided)
    if preprocessor:
        # print("preprocessing...") # debug
        processed_image = preprocessor(processed_image, page.number)
        # processed_image.show() # debug

    # Annotate the processed image using the Vision API
    response = process_single_image(processed_image, client)

    if response:
        text_annotations = response.text_annotations
        # Extract full text and word locations
        full_page_text = text_annotations[0].description if text_annotations else ""
        word_locations = text_annotations[1:] if len(text_annotations) > 1 else []
    else:
        # return empty data
        full_page_text = ""
        word_locations = [EntityAnnotation()]
        text_annotations = [
            EntityAnnotation()
        ]  # create empty data structures to allow storing to proceed.

    # Create an annotated image with bounding boxes and labels
    annotated_image = annotate_image_with_text(
        processed_image, text_annotations, annotation_font_path
    )

    # Get page dimensions (from the original PDF page, not the image)
    page_dimensions = get_page_dimensions(page)

    return (
        full_page_text,
        word_locations,
        annotated_image,
        original_image,
        page_dimensions,
    )

process_single_image(image, client, feature_type=DEFAULT_ANNOTATION_METHOD, language_hints=DEFAULT_ANNOTATION_LANGUAGE_HINTS)

Processes a single image with the Google Vision API and returns the annotation response.

Parameters:

- image (Image.Image): The preprocessed Pillow image object. Required.
- client (vision.ImageAnnotatorClient): Google Vision API client for text detection. Required.
- feature_type (str): Type of text detection to use ('TEXT_DETECTION' or 'DOCUMENT_TEXT_DETECTION'). Default: DEFAULT_ANNOTATION_METHOD.
- language_hints (List): Language hints for OCR. Default: DEFAULT_ANNOTATION_LANGUAGE_HINTS.

Returns:

- vision.AnnotateImageResponse: The full Vision API response; text annotations are available via response.text_annotations.

Raises:

- ValueError: If an invalid feature type is given.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def process_single_image(
    image: Image.Image,
    client: vision.ImageAnnotatorClient,
    feature_type: str = DEFAULT_ANNOTATION_METHOD,
    language_hints: List = DEFAULT_ANNOTATION_LANGUAGE_HINTS,
) -> vision.AnnotateImageResponse:
    """
    Processes a single image with the Google Vision API and returns the annotation response.

    Parameters:
        image (Image.Image): The preprocessed Pillow image object.
        client (vision.ImageAnnotatorClient): Google Vision API client for text detection.
        feature_type (str): Type of text detection to use ('TEXT_DETECTION' or 'DOCUMENT_TEXT_DETECTION').
        language_hints (List): Language hints for OCR.

    Returns:
        vision.AnnotateImageResponse: The full Vision API response; text annotations
            are available via ``response.text_annotations``.

    Raises:
        ValueError: If an invalid feature type is given.
    """
    # Convert the Pillow image to bytes
    image_bytes = pil_to_bytes(image, format="PNG")

    # Map feature type
    feature_map = {
        "TEXT_DETECTION": vision.Feature.Type.TEXT_DETECTION,
        "DOCUMENT_TEXT_DETECTION": vision.Feature.Type.DOCUMENT_TEXT_DETECTION,
    }
    if feature_type not in feature_map:
        raise ValueError(
            f"Invalid feature type '{feature_type}'. Use 'TEXT_DETECTION' or 'DOCUMENT_TEXT_DETECTION'."
        )

    # Prepare Vision API request
    vision_image = vision.Image(content=image_bytes)
    features = [vision.Feature(type=feature_map[feature_type])]
    image_context = vision.ImageContext(language_hints=language_hints)

    # Make the API call
    response = client.annotate_image(
        {"image": vision_image, "features": features, "image_context": image_context}
    )

    return response

save_processed_pdf_data(output_dir, journal_name, text_pages, word_locations, annotated_images, unannotated_images)

Saves processed PDF data to files for later reloading.

Parameters:

- output_dir (Path): Directory to save the data (as a Path object). Required.
- journal_name (str): Name used for the output subdirectory (usually the PDF name without extension). Required.
- text_pages (List[str]): Extracted full-page text. Required.
- word_locations (List[List[EntityAnnotation]]): Word locations and annotations from Vision API. Required.
- annotated_images (List[Image.Image]): Annotated images with bounding boxes. Required.
- unannotated_images (List[Image.Image]): Raw unannotated images. Required.

Returns:

- None

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def save_processed_pdf_data(
    output_dir: Path,
    journal_name: str,
    text_pages: List[str],
    word_locations: List[List[EntityAnnotation]],
    annotated_images: List[Image.Image],
    unannotated_images: List[Image.Image],
):
    """
    Saves processed PDF data to files for later reloading.

    Parameters:
        output_dir (Path): Directory to save the data (as a Path object).
        journal_name (str): Name used for the output subdirectory (usually the PDF name without extension).
        text_pages (List[str]): Extracted full-page text.
        word_locations (List[List[EntityAnnotation]]): Word locations and annotations from Vision API.
        annotated_images (List[PIL.Image.Image]): Annotated images with bounding boxes.
        unannotated_images (List[PIL.Image.Image]): Raw unannotated images.

    Returns:
        None
    """
    # Create output directories
    base_path = output_dir / journal_name / "ocr_data"
    images_dir = base_path / "images"

    base_path.mkdir(parents=True, exist_ok=True)
    images_dir.mkdir(parents=True, exist_ok=True)

    # Save text data
    text_pages_file = base_path / "text_pages.json"
    with text_pages_file.open("w", encoding="utf-8") as f:
        json.dump(text_pages, f, indent=4, ensure_ascii=False)

    # Save word locations as JSON
    word_locations_file = base_path / "word_locations.json"
    serialized_word_locations = serialize_entity_annotations_to_json(word_locations)
    with word_locations_file.open("w", encoding="utf-8") as f:
        f.write(serialized_word_locations)

    # Save images
    for i, annotated_image in enumerate(annotated_images):
        annotated_file = images_dir / f"annotated_page_{i + 1}.png"
        annotated_image.save(annotated_file)
    for i, unannotated_image in enumerate(unannotated_images):
        unannotated_file = images_dir / f"unannotated_page_{i + 1}.png"
        unannotated_image.save(unannotated_file)

    # Save metadata
    metadata = {
        "source_pdf": journal_name,
        "page_count": len(text_pages),
        "images_directory": str(
            images_dir
        ),  # Convert Path to string for JSON serialization
        "files": {
            "text_pages": "text_pages.json",
            "word_locations": "word_locations.json",
        },
    }
    metadata_file = base_path / "metadata.json"
    with metadata_file.open("w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=4)

    print(f"Processed data saved in: {base_path}")
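The resulting on-disk layout can be reproduced with the standard library alone; this sketch mirrors the directory structure and metadata keys of the function above, with `demo_journal` and the page texts as hypothetical stand-in data:

```python
import json
import tempfile
from pathlib import Path

# Recreate the directory skeleton and metadata file that
# save_processed_pdf_data produces, using stand-in data.
output_dir = Path(tempfile.mkdtemp())
journal_name = "demo_journal"  # hypothetical journal name

base_path = output_dir / journal_name / "ocr_data"
images_dir = base_path / "images"
images_dir.mkdir(parents=True, exist_ok=True)

text_pages = ["Page one text", "Page two text"]
(base_path / "text_pages.json").write_text(
    json.dumps(text_pages, indent=4, ensure_ascii=False), encoding="utf-8"
)

metadata = {
    "source_pdf": journal_name,
    "page_count": len(text_pages),
    "images_directory": str(images_dir),
    "files": {"text_pages": "text_pages.json", "word_locations": "word_locations.json"},
}
(base_path / "metadata.json").write_text(json.dumps(metadata, indent=4), encoding="utf-8")

# Reloading later only needs metadata.json to locate everything else.
reloaded = json.loads((base_path / "metadata.json").read_text(encoding="utf-8"))
print(reloaded["page_count"])  # 2
```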

serialize_entity_annotations_to_json(annotations)

Serializes a nested list of EntityAnnotation objects into a JSON-compatible format using Base64 encoding.

Parameters:

- annotations (List[List[EntityAnnotation]]): The nested list of EntityAnnotation objects. Required.

Returns:

- str: The serialized data in JSON format as a string.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def serialize_entity_annotations_to_json(
    annotations: List[List[EntityAnnotation]],
) -> str:
    """
    Serializes a nested list of EntityAnnotation objects into a JSON-compatible format using Base64 encoding.

    Parameters:
        annotations (List[List[EntityAnnotation]]): The nested list of EntityAnnotation objects.

    Returns:
        str: The serialized data in JSON format as a string.
    """
    serialized_data = []
    for page_annotations in annotations:
        serialized_page = [
            base64.b64encode(annotation.SerializeToString()).decode("utf-8")
            for annotation in page_annotations
        ]
        serialized_data.append(serialized_page)

    # Convert to a JSON string
    return json.dumps(serialized_data, indent=4)
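The Base64-over-JSON round trip can be exercised without the Vision types; this sketch substitutes raw bytes for `EntityAnnotation.SerializeToString()` output (that protobuf step is the only assumption swapped out):

```python
import base64
import json

# Stand-ins for per-page lists of serialized protobuf messages.
pages = [[b"annotation-1", b"annotation-2"], [b"annotation-3"]]

# Serialize: bytes -> Base64 text -> JSON
# (mirrors serialize_entity_annotations_to_json).
serialized = json.dumps(
    [[base64.b64encode(a).decode("utf-8") for a in page] for page in pages],
    indent=4,
)

# Deserialize: JSON -> Base64 text -> bytes
# (mirrors deserialize_entity_annotations_from_json, with
# EntityAnnotation.deserialize replaced by the identity on raw bytes).
restored = [[base64.b64decode(a) for a in page] for page in json.loads(serialized)]
print(restored == pages)  # True
```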

start_image_annotator_client(credentials_file=None, api_endpoint='vision.googleapis.com', timeout=(10, 30), enable_logging=False)

Starts and returns a Google Vision API ImageAnnotatorClient with optional configuration.

Parameters:

- credentials_file (Optional[str]): Path to the credentials JSON file. If None, uses the default environment variable. Default: None.
- api_endpoint (str): Custom API endpoint for the Vision API. Default: 'vision.googleapis.com'.
- timeout (Tuple[int, int]): Connection and read timeouts in seconds (currently informational only). Default: (10, 30).
- enable_logging (bool): Enable detailed logging for debugging. Default: False.

Returns:

- vision.ImageAnnotatorClient: Configured Vision API client.

Raises:

- FileNotFoundError: If the specified credentials file is not found.
- Exception: For unexpected errors during client setup.

Example:

    client = start_image_annotator_client(
        credentials_file="/path/to/credentials.json",
        api_endpoint="vision.googleapis.com",
        timeout=(10, 30),
        enable_logging=True,
    )
    print("Google Vision API client initialized.")

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def start_image_annotator_client(
    credentials_file: Optional[str] = None,
    api_endpoint: str = "vision.googleapis.com",
    timeout: Tuple[int, int] = (10, 30),
    enable_logging: bool = False,
) -> vision.ImageAnnotatorClient:
    """
    Starts and returns a Google Vision API ImageAnnotatorClient with optional configuration.

    Parameters:
        credentials_file (str): Path to the credentials JSON file. If None, uses the default environment variable.
        api_endpoint (str): Custom API endpoint for the Vision API. Default is the global endpoint.
        timeout (Tuple[int, int]): Connection and read timeouts in seconds (currently informational only). Default is (10, 30).
        enable_logging (bool): Enable detailed logging for debugging. Default is False.

    Returns:
        vision.ImageAnnotatorClient: Configured Vision API client.

    Raises:
        FileNotFoundError: If the specified credentials file is not found.
        Exception: For unexpected errors during client setup.

    Example:
        >>> client = start_image_annotator_client(
        >>>     credentials_file="/path/to/credentials.json",
        >>>     api_endpoint="vision.googleapis.com",
        >>>     timeout=(10, 30),
        >>>     enable_logging=True
        >>> )
        >>> print("Google Vision API client initialized.")
    """
    try:
        # Set up credentials
        if credentials_file:
            if not os.path.exists(credentials_file):
                raise FileNotFoundError(
                    f"Credentials file '{credentials_file}' not found."
                )
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_file

        # Configure client options
        client_options = {"api_endpoint": api_endpoint}
        client = vision.ImageAnnotatorClient(client_options=client_options)

        # Optionally enable logging
        if enable_logging:
            print(f"Vision API Client started with endpoint: {api_endpoint}")
            print(f"Timeout settings: Connect={timeout[0]}s, Read={timeout[1]}s")

        return client

    except Exception as e:
        raise Exception(f"Failed to initialize ImageAnnotatorClient: {e}") from e
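The credentials handling above reduces to a small environment-variable pattern; a stdlib-only sketch (no Vision client is constructed here, and `set_credentials` is a hypothetical helper that isolates just that step):

```python
import os
import tempfile

def set_credentials(credentials_file=None):
    """Point GOOGLE_APPLICATION_CREDENTIALS at an explicit file, if given."""
    if credentials_file:
        if not os.path.exists(credentials_file):
            raise FileNotFoundError(f"Credentials file '{credentials_file}' not found.")
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_file
    # With no argument, the client library falls back to the ambient environment.

# Create a placeholder credentials file to point at.
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f:
    creds_path = f.name

set_credentials(creds_path)
print(os.environ["GOOGLE_APPLICATION_CREDENTIALS"] == creds_path)  # True
```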

ocr_editor

current_image = st.session_state.current_image module-attribute
current_page_index = st.session_state.current_page_index module-attribute
current_text = pages[current_page_index] module-attribute
edited_text = st.text_area('Edit OCR Text', value=(st.session_state.current_text), key=f'text_area_{st.session_state.current_page_index}', height=400) module-attribute
image_directory = st.sidebar.text_input('Image Directory', value='./images') module-attribute
ocr_text_directory = st.sidebar.text_input('OCR Text Directory', value='./ocr_text') module-attribute
pages = st.session_state.pages module-attribute
save_path = os.path.join(ocr_text_directory, 'updated_ocr.xml') module-attribute
tree = st.session_state.tree module-attribute
uploaded_image_file = st.sidebar.file_uploader('Upload an Image', type=['jpg', 'jpeg', 'png', 'pdf']) module-attribute
uploaded_text_file = st.sidebar.file_uploader('Upload OCR Text File', type=['xml']) module-attribute
extract_pages(tree)

Extract page data from the XML tree.

Parameters:

- tree (etree.ElementTree): Parsed XML tree. Required.

Returns:

- list: A list of dictionaries containing 'number' and 'text' for each page.

Source code in src/tnh_scholar/ocr_processing/ocr_editor.py
def extract_pages(tree):
    """
    Extract page data from the XML tree.

    Args:
        tree (etree.ElementTree): Parsed XML tree.

    Returns:
        list: A list of dictionaries containing 'number' and 'text' for each page.
    """
    pages = []
    for page in tree.xpath("//page"):
        page_number = page.get("page")
        ocr_text = page.text.strip() if page.text else ""
        pages.append({"number": page_number, "text": ocr_text})
    return pages
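The `<page>` XML shape that `extract_pages` expects can be illustrated with the standard library's `xml.etree.ElementTree` in place of lxml; the element and attribute names match the function above, with `findall` standing in for the lxml `xpath` call:

```python
import io
import xml.etree.ElementTree as ET

xml_data = b"""<document>
  <page page="1">First page OCR text.</page>
  <page page="2">Second page OCR text.</page>
</document>"""

tree = ET.parse(io.BytesIO(xml_data))

# Same extraction logic as extract_pages, via findall instead of lxml xpath.
pages = [
    {"number": page.get("page"), "text": (page.text or "").strip()}
    for page in tree.findall(".//page")
]
print(pages[0])  # {'number': '1', 'text': 'First page OCR text.'}
```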
load_xml(file_obj)

Load an XML file from a file-like object.

Source code in src/tnh_scholar/ocr_processing/ocr_editor.py
def load_xml(file_obj):
    """
    Load an XML file from a file-like object.
    """
    try:
        tree = etree.parse(file_obj)  # Directly parse the file-like object
        return tree
    except etree.XMLSyntaxError as e:
        st.error(f"Error parsing XML file: {e}")
        return None
save_xml(tree, file_path)

Save the modified XML tree to a file.

Source code in src/tnh_scholar/ocr_processing/ocr_editor.py
def save_xml(tree, file_path):
    """
    Save the modified XML tree to a file.
    """
    with open(file_path, "wb") as file:
        tree.write(file, pretty_print=True, encoding="utf-8", xml_declaration=True)

ocr_processing

DEFAULT_ANNOTATION_FONT_PATH = Path('/System/Library/Fonts/Supplemental/Arial.ttf') module-attribute
DEFAULT_ANNOTATION_FONT_SIZE = 12 module-attribute
DEFAULT_ANNOTATION_LANGUAGE_HINTS = ['vi'] module-attribute
DEFAULT_ANNOTATION_METHOD = 'DOCUMENT_TEXT_DETECTION' module-attribute
DEFAULT_ANNOTATION_OFFSET = 2 module-attribute
logger = logging.getLogger('ocr_processing') module-attribute
PDFParseWarning

Bases: Warning

Custom warning class for PDF parsing issues. Encapsulates minimal logic for displaying warnings with a custom format.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
class PDFParseWarning(Warning):
    """
    Custom warning class for PDF parsing issues.
    Encapsulates minimal logic for displaying warnings with a custom format.
    """

    @staticmethod
    def warn(message: str):
        """
        Display a warning message with custom formatting.

        Parameters:
            message (str): The warning message to display.
        """
        formatted_message = f"\033[93mPDFParseWarning: {message}\033[0m"
        print(formatted_message)  # Simply prints the warning
warn(message) staticmethod

Display a warning message with custom formatting.

Parameters:

- message (str): The warning message to display. Required.
Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
@staticmethod
def warn(message: str):
    """
    Display a warning message with custom formatting.

    Parameters:
        message (str): The warning message to display.
    """
    formatted_message = f"\033[93mPDFParseWarning: {message}\033[0m"
    print(formatted_message)  # Simply prints the warning
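The `\033[93m ... \033[0m` wrapper used by `warn` is a plain ANSI escape pair (bright-yellow on, reset off); a minimal sketch of the same formatting, with `format_warning` as a hypothetical stand-in:

```python
# ANSI SGR codes: 93 = bright yellow foreground, 0 = reset attributes.
YELLOW, RESET = "\033[93m", "\033[0m"

def format_warning(message: str) -> str:
    """Wrap a message the way PDFParseWarning.warn colors its output."""
    return f"{YELLOW}PDFParseWarning: {message}{RESET}"

line = format_warning("Page 3 empty")
print(line)  # renders in yellow on ANSI-capable terminals
```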
annotate_image_with_text(image, text_annotations, annotation_font_path, font_size=12)

Annotates a PIL image with bounding boxes and text descriptions from OCR results.

Parameters:

- image (Image.Image): The input PIL image to annotate. Required.
- text_annotations (List[EntityAnnotation]): OCR results containing bounding boxes and text. Required.
- annotation_font_path (str): Path to the font file for text annotations. Required.
- font_size (int): Font size for text annotations. Default: 12.

Returns:

- Image.Image: The annotated PIL image.

Raises:

- ValueError: If the input image is None.
- IOError: If the font file cannot be loaded.
- Exception: For any other unexpected errors.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def annotate_image_with_text(
    image: Image.Image,
    text_annotations: List[EntityAnnotation],
    annotation_font_path: str,
    font_size: int = 12,
) -> Image.Image:
    """
    Annotates a PIL image with bounding boxes and text descriptions from OCR results.

    Parameters:
        image (Image.Image): The input PIL image to annotate.
        text_annotations (List[EntityAnnotation]): OCR results containing bounding boxes and text.
        annotation_font_path (str): Path to the font file for text annotations.
        font_size (int): Font size for text annotations.

    Returns:
        Image.Image: The annotated PIL image.

    Raises:
        ValueError: If the input image is None.
        IOError: If the font file cannot be loaded.
        Exception: For any other unexpected errors.
    """
    if image is None:
        raise ValueError("The input image is None.")

    try:
        font = ImageFont.truetype(annotation_font_path, font_size)
    except IOError as e:
        raise IOError(f"Failed to load the font from '{annotation_font_path}': {e}")

    draw = ImageDraw.Draw(image)

    try:
        for i, text_obj in enumerate(text_annotations):
            vertices = [
                (vertex.x, vertex.y) for vertex in text_obj.bounding_poly.vertices
            ]
            if (
                len(vertices) == 4
            ):  # Ensure there are exactly 4 vertices for a rectangle
                # Draw the bounding box
                draw.polygon(vertices, outline="red", width=2)

                # Skip the first bounding box (whole text region)
                if i > 0:
                    # Offset the text position slightly for clarity
                    text_position = (vertices[0][0] + 2, vertices[0][1] + 2)
                    draw.text(
                        text_position, text_obj.description, fill="red", font=font
                    )

    except AttributeError as e:
        raise ValueError(f"Invalid text annotation structure: {e}")
    except Exception as e:
        raise Exception(f"An error occurred during image annotation: {e}")

    return image
build_processed_pdf(pdf_path, client, preprocessor=None, annotation_font_path=DEFAULT_ANNOTATION_FONT_PATH)

Processes a PDF document, extracting text, word locations, annotated images, and unannotated images.

Parameters:

- pdf_path (Path): Path to the PDF file. Required.
- client (vision.ImageAnnotatorClient): Google Vision API client for text detection. Required.
- preprocessor (Optional[Callable]): Per-page image preprocessing function. Default: None.
- annotation_font_path (Path): Path to the font file for annotations. Default: DEFAULT_ANNOTATION_FONT_PATH.

Returns:

- Tuple[List[str], List[List[vision.EntityAnnotation]], List[Image.Image], List[Image.Image]]:
  - List of extracted full-page texts (one entry per page).
  - List of word locations (list of vision.EntityAnnotation objects for each page).
  - List of annotated images (with bounding boxes and text annotations).
  - List of unannotated images (raw page images).

Raises:

- FileNotFoundError: If the specified PDF file does not exist.
- ValueError: If the PDF file is invalid or contains no pages.
- Exception: For any unexpected errors during processing.

Example:

    from pathlib import Path
    from google.cloud import vision

    pdf_path = Path("/path/to/example.pdf")
    font_path = Path("/path/to/fonts/Arial.ttf")
    client = vision.ImageAnnotatorClient()
    try:
        text_pages, word_locations_list, annotated_images, unannotated_images = build_processed_pdf(
            pdf_path, client, annotation_font_path=font_path
        )
        print(f"Processed {len(text_pages)} pages successfully!")
    except Exception as e:
        print(f"Error processing PDF: {e}")

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def build_processed_pdf(
    pdf_path: Path,
    client: vision.ImageAnnotatorClient,
    preprocessor: Optional[Callable] = None,
    annotation_font_path: Path = DEFAULT_ANNOTATION_FONT_PATH,
) -> Tuple[
    List[str], List[List[vision.EntityAnnotation]], List[Image.Image], List[Image.Image]
]:
    """
    Processes a PDF document, extracting text, word locations, annotated images, and unannotated images.

    Parameters:
        pdf_path (Path): Path to the PDF file.
        client (vision.ImageAnnotatorClient): Google Vision API client for text detection.
        preprocessor (Optional[Callable]): Per-page image preprocessing function. Defaults to None.
        annotation_font_path (Path): Path to the font file for annotations.

    Returns:
        Tuple[List[str], List[List[vision.EntityAnnotation]], List[Image.Image], List[Image.Image]]:
            - List of extracted full-page texts (one entry per page).
            - List of word locations (list of `vision.EntityAnnotation` objects for each page).
            - List of annotated images (with bounding boxes and text annotations).
            - List of unannotated images (raw page images).

    Raises:
        FileNotFoundError: If the specified PDF file does not exist.
        ValueError: If the PDF file is invalid or contains no pages.
        Exception: For any unexpected errors during processing.

    Example:
        >>> from pathlib import Path
        >>> from google.cloud import vision
        >>> pdf_path = Path("/path/to/example.pdf")
        >>> font_path = Path("/path/to/fonts/Arial.ttf")
        >>> client = vision.ImageAnnotatorClient()
        >>> try:
        >>>     text_pages, word_locations_list, annotated_images, unannotated_images = build_processed_pdf(
        >>>         pdf_path, client, annotation_font_path=font_path
        >>>     )
        >>>     print(f"Processed {len(text_pages)} pages successfully!")
        >>> except Exception as e:
        >>>     print(f"Error processing PDF: {e}")
    """
    try:
        doc = load_pdf_pages(pdf_path)
    except FileNotFoundError as fnf_error:
        raise FileNotFoundError(f"Error loading PDF: {fnf_error}")
    except ValueError as ve:
        raise ValueError(f"Invalid PDF file: {ve}")
    except Exception as e:
        raise Exception(f"An unexpected error occurred while loading the PDF: {e}")

    if doc.page_count == 0:
        raise ValueError(f"The PDF file '{pdf_path}' contains no pages.")

    logger.info(f"Processing file with {doc.page_count} pages:\n\t{pdf_path}")

    text_pages = []
    word_locations_list = []
    annotated_images = []
    unannotated_images = []
    first_page_dimensions = None

    for page_num in range(doc.page_count):
        logger.info(f"Processing page {page_num + 1}/{doc.page_count}...")

        try:
            page = doc.load_page(page_num)
            (
                full_page_text,
                word_locations,
                annotated_image,
                unannotated_image,
                page_dimensions,
            ) = process_page(page, client, annotation_font_path, preprocessor)

            if full_page_text:  # this is not an empty page

                if page_num == 0:  # save first page info
                    first_page_dimensions = page_dimensions
                elif (
                    page_dimensions != first_page_dimensions
                ):  # verify page dimensions are consistent
                    PDFParseWarning.warn(
                        f"Page {page_num + 1} has different dimensions than page 1."
                        f"({page_dimensions}) compared to the first page: ({first_page_dimensions})."
                    )

                text_pages.append(full_page_text)
                word_locations_list.append(word_locations)
                annotated_images.append(annotated_image)
                unannotated_images.append(unannotated_image)
            else:
                PDFParseWarning.warn(
                    f"Page {page_num + 1} empty; page skipped.\n"
                    f"  (Note that total document length will be reduced.)"
                )

        except ValueError as ve:
            print(f"ValueError on page {page_num + 1}: {ve}")
        except OSError as oe:
            print(f"OSError on page {page_num + 1}: {oe}")
        except Exception as e:
            print(f"Unexpected error on page {page_num + 1}: {e}")

    logger.info(f"Last page dimensions: {page_dimensions}")
    return text_pages, word_locations_list, annotated_images, unannotated_images
deserialize_entity_annotations_from_json(data)

Deserializes JSON data into a nested list of EntityAnnotation objects.

Parameters:

- data (str): The JSON string containing serialized annotations. Required.

Returns:

- List[List[EntityAnnotation]]: The reconstructed nested list of EntityAnnotation objects.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def deserialize_entity_annotations_from_json(data: str) -> List[List[EntityAnnotation]]:
    """
    Deserializes JSON data into a nested list of EntityAnnotation objects.

    Parameters:
        data (str): The JSON string containing serialized annotations.

    Returns:
        List[List[EntityAnnotation]]: The reconstructed nested list of EntityAnnotation objects.
    """
    serialized_data = json.loads(data)
    deserialized_data = []

    for serialized_page in serialized_data:
        page_annotations = [
            EntityAnnotation.deserialize(base64.b64decode(serialized_annotation))
            for serialized_annotation in serialized_page
        ]
        deserialized_data.append(page_annotations)

    return deserialized_data
extract_image_from_page(page)

Extracts the first image from the given PDF page and returns it as a PIL Image.

Parameters:

- page (fitz.Page): The PDF page object. Required.

Returns:

- Image.Image: The first image on the page as a Pillow Image object.

Raises:

- ValueError: If no images are found on the page or the image data is incomplete.
- Exception: For unexpected errors during image extraction.

Example:

    import fitz
    from PIL import Image

    doc = fitz.open("/path/to/document.pdf")
    page = doc.load_page(0)  # Load the first page
    try:
        image = extract_image_from_page(page)
        image.show()  # Display the image
    except Exception as e:
        print(f"Error extracting image: {e}")

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def extract_image_from_page(page: fitz.Page) -> Image.Image:
    """
    Extracts the first image from the given PDF page and returns it as a PIL Image.

    Parameters:
        page (fitz.Page): The PDF page object.

    Returns:
        Image.Image: The first image on the page as a Pillow Image object.

    Raises:
        ValueError: If no images are found on the page or the image data is incomplete.
        Exception: For unexpected errors during image extraction.

    Example:
        >>> import fitz
        >>> from PIL import Image
        >>> doc = fitz.open("/path/to/document.pdf")
        >>> page = doc.load_page(0)  # Load the first page
        >>> try:
        >>>     image = extract_image_from_page(page)
        >>>     image.show()  # Display the image
        >>> except Exception as e:
        >>>     print(f"Error extracting image: {e}")
    """
    try:
        # Get images from the page
        images = page.get_images(full=True)
        if not images:
            raise ValueError("No images found on the page.")

        # Extract the first image reference
        xref = images[0][0]  # Get the first image's xref
        base_image = page.parent.extract_image(xref)

        # Validate the extracted image data
        if (
            "image" not in base_image
            or "width" not in base_image
            or "height" not in base_image
        ):
            raise ValueError("The extracted image data is incomplete.")

        # Convert the raw image bytes into a Pillow image
        image_bytes = base_image["image"]
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        return pil_image

    except ValueError as ve:
        raise ve  # Re-raise for calling functions to handle
    except Exception as e:
        raise Exception(f"An unexpected error occurred during image extraction: {e}")
get_page_dimensions(page)

Extracts the width and height of a single PDF page in both inches and pixels.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `page` | `fitz.Page` | A single PDF page object from PyMuPDF. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `dict` | A dictionary containing the width and height of the page in inches and pixels. |

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def get_page_dimensions(page: fitz.Page) -> dict:
    """
    Extracts the width and height of a single PDF page in both inches and pixels.

    Args:
        page (fitz.Page): A single PDF page object from PyMuPDF.

    Returns:
        dict: A dictionary containing the width and height of the page in inches and pixels.
    """
    # Get page dimensions in points and convert to inches
    page_width_pts, page_height_pts = page.rect.width, page.rect.height
    page_width_in = page_width_pts / 72  # Convert points to inches
    page_height_in = page_height_pts / 72

    # Extract the first image on the page (if any) to get pixel dimensions
    images = page.get_images(full=True)
    if images:
        xref = images[0][0]
        base_image = page.parent.extract_image(xref)
        width_px = base_image["width"]
        height_px = base_image["height"]
    else:
        width_px, height_px = None, None  # No image found on the page

    # Return dimensions
    return {
        "width_in": page_width_in,
        "height_in": page_height_in,
        "width_px": width_px,
        "height_px": height_px,
    }
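The conversion above relies on the PDF convention that one inch equals 72 points. A minimal standalone sketch of the same arithmetic (no PyMuPDF required; the US Letter dimensions are illustrative):

```python
# PDF pages are measured in points: 1 inch = 72 points by definition.
POINTS_PER_INCH = 72

def points_to_inches(width_pts: float, height_pts: float) -> tuple:
    """Convert page dimensions from points to inches."""
    return width_pts / POINTS_PER_INCH, height_pts / POINTS_PER_INCH

# A US Letter page is 612 x 792 points.
print(points_to_inches(612, 792))  # (8.5, 11.0)
```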
load_pdf_pages(pdf_path)

Opens the PDF document and returns the fitz Document object.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `pdf_path` | `Path` | The path to the PDF file. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `fitz.Document` | The loaded PDF document. |

Raises:

| Type | Description |
|------|-------------|
| `FileNotFoundError` | If the specified file does not exist. |
| `ValueError` | If the file is not a valid PDF document. |
| `Exception` | For any unexpected error. |

Example:

```python
from pathlib import Path

pdf_path = Path("/path/to/example.pdf")
try:
    pdf_doc = load_pdf_pages(pdf_path)
    print(f"PDF contains {pdf_doc.page_count} pages.")
except Exception as e:
    print(f"Error loading PDF: {e}")
```

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def load_pdf_pages(pdf_path: Path) -> fitz.Document:
    """
    Opens the PDF document and returns the fitz Document object.

    Parameters:
        pdf_path (Path): The path to the PDF file.

    Returns:
        fitz.Document: The loaded PDF document.

    Raises:
        FileNotFoundError: If the specified file does not exist.
        ValueError: If the file is not a valid PDF document.
        Exception: For any unexpected error.

    Example:
        >>> from pathlib import Path
        >>> pdf_path = Path("/path/to/example.pdf")
        >>> try:
        >>>     pdf_doc = load_pdf_pages(pdf_path)
        >>>     print(f"PDF contains {pdf_doc.page_count} pages.")
        >>> except Exception as e:
        >>>     print(f"Error loading PDF: {e}")
    """
    if not pdf_path.exists():
        raise FileNotFoundError(f"The file '{pdf_path}' does not exist.")

    if pdf_path.suffix.lower() != ".pdf":
        raise ValueError(
            f"The file '{pdf_path}' is not a valid PDF document (expected '.pdf')."
        )

    try:
        return fitz.open(str(pdf_path))  # PyMuPDF expects a string path
    except Exception as e:
        raise Exception(f"An unexpected error occurred while opening the PDF: {e}")
load_processed_PDF_data(base_path)

Loads processed PDF data from files using metadata for file references.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `base_path` | `Path` | Directory where the processed data (metadata, text, and images) is stored. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `Tuple[List[str], List[List[EntityAnnotation]], List[Image.Image], List[Image.Image]]` | Loaded text pages; word locations (a list of `EntityAnnotation` objects for each page); annotated images; unannotated images. |

Raises:

| Type | Description |
|------|-------------|
| `FileNotFoundError` | If any required files are missing. |
| `ValueError` | If the metadata file is incomplete or invalid. |

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def load_processed_PDF_data(
    base_path: Path,
) -> Tuple[
    List[str], List[List[EntityAnnotation]], List[Image.Image], List[Image.Image]
]:
    """
    Loads processed PDF data from files using metadata for file references.

    Parameters:
        base_path (Path): Directory where the processed data (metadata, text,
            and images) is stored.

    Returns:
        Tuple[List[str], List[List[EntityAnnotation]], List[Image.Image], List[Image.Image]]:
            - Loaded text pages.
            - Word locations (list of `EntityAnnotation` objects for each page).
            - Annotated images.
            - Unannotated images.

    Raises:
        FileNotFoundError: If any required files are missing.
        ValueError: If the metadata file is incomplete or invalid.
    """
    metadata_file = base_path / "metadata.json"

    # Load metadata
    try:
        with metadata_file.open("r", encoding="utf-8") as f:
            metadata = json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(f"Metadata file '{metadata_file}' not found.")
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid metadata file format: {e}")

    # Extract file paths from metadata
    text_pages_file = base_path / metadata.get("files", {}).get(
        "text_pages", "text_pages.json"
    )
    word_locations_file = base_path / metadata.get("files", {}).get(
        "word_locations", "word_locations.json"
    )
    images_dir = Path(metadata.get("images_directory", base_path / "images"))

    # Validate file paths
    if not text_pages_file.exists():
        raise FileNotFoundError(f"Text pages file '{text_pages_file}' not found.")
    if not word_locations_file.exists():
        raise FileNotFoundError(
            f"Word locations file '{word_locations_file}' not found."
        )
    if not images_dir.exists() or not images_dir.is_dir():
        raise FileNotFoundError(f"Images directory '{images_dir}' not found.")

    # Load text pages
    with text_pages_file.open("r", encoding="utf-8") as f:
        text_pages = json.load(f)

    # Load word locations
    with word_locations_file.open("r", encoding="utf-8") as f:
        serialized_word_locations = f.read()
        word_locations = deserialize_entity_annotations_from_json(
            serialized_word_locations
        )

    # Load images
    annotated_images = []
    unannotated_images = []
    for file in sorted(
        images_dir.iterdir()
    ):  # Iterate over files in the images directory
        if file.name.startswith("annotated_page_") and file.suffix == ".png":
            annotated_images.append(Image.open(file))
        elif file.name.startswith("unannotated_page_") and file.suffix == ".png":
            unannotated_images.append(Image.open(file))

    # Ensure images were loaded correctly
    if not annotated_images or not unannotated_images:
        raise ValueError(f"No images found in the directory '{images_dir}'.")

    return text_pages, word_locations, annotated_images, unannotated_images
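The metadata-driven file resolution used above can be exercised without any OCR output by writing a minimal `metadata.json` by hand. A sketch (the temp directory and field values are illustrative; the filenames mirror the defaults in the source):

```python
import json
import tempfile
from pathlib import Path

base_path = Path(tempfile.mkdtemp())

# Minimal metadata.json mirroring what save_processed_pdf_data writes.
metadata = {
    "source_pdf": "example",
    "page_count": 2,
    "images_directory": str(base_path / "images"),
    "files": {
        "text_pages": "text_pages.json",
        "word_locations": "word_locations.json",
    },
}
(base_path / "metadata.json").write_text(json.dumps(metadata), encoding="utf-8")

# Resolve file paths relative to base_path, falling back to the defaults.
loaded = json.loads((base_path / "metadata.json").read_text(encoding="utf-8"))
text_pages_file = base_path / loaded.get("files", {}).get("text_pages", "text_pages.json")
print(text_pages_file.name)  # text_pages.json
```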
make_image_preprocess_mask(mask_height)

Creates a preprocessing function that masks a specified height at the bottom of the image.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `mask_height` | `float` | The proportion of the image height to mask at the bottom (0.0 to 1.0). | *required* |

Returns:

| Type | Description |
|------|-------------|
| `Callable[[Image.Image, int], Image.Image]` | A preprocessing function that takes an image and page number as input and returns the processed image. |

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def make_image_preprocess_mask(
    mask_height: float,
) -> Callable[[Image.Image, int], Image.Image]:
    """
    Creates a preprocessing function that masks a specified height at the bottom of the image.

    Parameters:
        mask_height (float): The proportion of the image height to mask at the bottom (0.0 to 1.0).

    Returns:
        Callable[[Image.Image, int], Image.Image]: A preprocessing function that takes an image
        and page number as input and returns the processed image.
    """

    def pre_process_image(image: Image.Image, page_number: int) -> Image.Image:
        """
        Preprocesses the image by masking the bottom region or performing other preprocessing steps.

        Parameters:
            image (Image.Image): The input image as a Pillow object.
            page_number (int): The page number of the image (useful for conditional preprocessing).

        Returns:
            Image.Image: The preprocessed image.
        """

        if page_number > 0:  # don't apply mask to cover page.
            # Draw directly on the image; pass in a copy if the original
            # must be preserved (process_page does this).
            draw = ImageDraw.Draw(image)

            # Get image dimensions
            width, height = image.size

            # Mask the bottom region based on the specified height proportion
            mask_pixels = int(height * mask_height)
            draw.rectangle([(0, height - mask_pixels), (width, height)], fill="black")

        return image

    return pre_process_image
pil_to_bytes(image, format='PNG')

Converts a Pillow image to raw bytes.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `image` | `Image.Image` | The Pillow image object to convert. | *required* |
| `format` | `str` | The format to save the image as (e.g., "PNG", "JPEG"). | `'PNG'` |

Returns:

| Type | Description |
|------|-------------|
| `bytes` | The raw bytes of the image. |

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def pil_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
    """
    Converts a Pillow image to raw bytes.

    Parameters:
        image (Image.Image): The Pillow image object to convert.
        format (str): The format to save the image as (e.g., "PNG", "JPEG"). Default is "PNG".

    Returns:
        bytes: The raw bytes of the image.
    """
    with io.BytesIO() as output:
        image.save(output, format=format)
        return output.getvalue()
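A short usage sketch of the in-memory pattern this function wraps, assuming Pillow is installed (the 4x4 red image is illustrative). PNG output always begins with a fixed signature, which gives a quick sanity check on the result:

```python
import io
from PIL import Image

image = Image.new("RGB", (4, 4), "red")

# Same pattern as pil_to_bytes: save into an in-memory buffer.
with io.BytesIO() as output:
    image.save(output, format="PNG")
    data = output.getvalue()

print(data[:4])  # b'\x89PNG'
```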
process_page(page, client, annotation_font_path, preprocessor=None)

Processes a single PDF page, extracting text, word locations, and annotated images.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `page` | `fitz.Page` | The PDF page object. | *required* |
| `client` | `vision.ImageAnnotatorClient` | Google Vision API client for text detection. | *required* |
| `annotation_font_path` | `str` | Path to the font file for annotations. | *required* |
| `preprocessor` | `Optional[Callable[[Image.Image, int], Image.Image]]` | Optional preprocessing function applied to the image before OCR. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `Tuple[str, List[vision.EntityAnnotation], Image.Image, Image.Image, dict]` | Full page text (str); word locations (List of `vision.EntityAnnotation`); annotated image (Pillow Image object); original unprocessed image (Pillow Image object); page dimensions (dict). |

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def process_page(
    page: fitz.Page,
    client: vision.ImageAnnotatorClient,
    annotation_font_path: str,
    preprocessor: Optional[Callable[[Image.Image, int], Image.Image]] = None,
) -> Tuple[str, List[vision.EntityAnnotation], Image.Image, Image.Image, dict]:
    """
    Processes a single PDF page, extracting text, word locations, and annotated images.

    Parameters:
        page (fitz.Page): The PDF page object.
        client (vision.ImageAnnotatorClient): Google Vision API client for text detection.
        annotation_font_path (str): Path to the font file for annotations.
        preprocessor (Optional[Callable[[Image.Image, int], Image.Image]]):
            Optional preprocessing function applied to the image before OCR.

    Returns:
        Tuple[str, List[vision.EntityAnnotation], Image.Image, Image.Image, dict]:
            - Full page text (str)
            - Word locations (List of vision.EntityAnnotation)
            - Annotated image (Pillow Image object)
            - Original unprocessed image (Pillow Image object)
            - Page dimensions (dict)
    """
    # Extract the original image from the PDF page
    original_image = extract_image_from_page(page)

    # Make a copy of the original image for processing
    processed_image = original_image.copy()

    # Apply the preprocessing function (if provided)
    if preprocessor:
        processed_image = preprocessor(processed_image, page.number)

    # Annotate the processed image using the Vision API
    response = process_single_image(processed_image, client)

    if response:
        text_annotations = response.text_annotations
        # Extract full text and word locations
        full_page_text = text_annotations[0].description if text_annotations else ""
        word_locations = text_annotations[1:] if len(text_annotations) > 1 else []
    else:
        # return empty data
        full_page_text = ""
        word_locations = [EntityAnnotation()]
        text_annotations = [
            EntityAnnotation()
        ]  # create empty data structures to allow storing to proceed.

    # Create an annotated image with bounding boxes and labels
    annotated_image = annotate_image_with_text(
        processed_image, text_annotations, annotation_font_path
    )

    # Get page dimensions (from the original PDF page, not the image)
    page_dimensions = get_page_dimensions(page)

    return (
        full_page_text,
        word_locations,
        annotated_image,
        original_image,
        page_dimensions,
    )
process_single_image(image, client, feature_type=DEFAULT_ANNOTATION_METHOD, language_hints=DEFAULT_ANNOTATION_LANGUAGE_HINTS)

Processes a single image with the Google Vision API and returns text annotations.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `image` | `Image.Image` | The preprocessed Pillow image object. | *required* |
| `client` | `vision.ImageAnnotatorClient` | Google Vision API client for text detection. | *required* |
| `feature_type` | `str` | Type of text detection to use (`'TEXT_DETECTION'` or `'DOCUMENT_TEXT_DETECTION'`). | `DEFAULT_ANNOTATION_METHOD` |
| `language_hints` | `List[str]` | Language hints for OCR. | `DEFAULT_ANNOTATION_LANGUAGE_HINTS` |

Returns:

| Type | Description |
|------|-------------|
| `vision.AnnotateImageResponse` | The full Vision API response; text annotations are available via `response.text_annotations`. |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If an invalid feature type is specified. |

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def process_single_image(
    image: Image.Image,
    client: vision.ImageAnnotatorClient,
    feature_type: str = DEFAULT_ANNOTATION_METHOD,
    language_hints: List[str] = DEFAULT_ANNOTATION_LANGUAGE_HINTS,
) -> vision.AnnotateImageResponse:
    """
    Processes a single image with the Google Vision API and returns text annotations.

    Parameters:
        image (Image.Image): The preprocessed Pillow image object.
        client (vision.ImageAnnotatorClient): Google Vision API client for text detection.
        feature_type (str): Type of text detection to use ('TEXT_DETECTION' or 'DOCUMENT_TEXT_DETECTION').
        language_hints (List[str]): Language hints for OCR.

    Returns:
        vision.AnnotateImageResponse: The full Vision API response; text
            annotations are available via `response.text_annotations`.

    Raises:
        ValueError: If an invalid feature type is specified.
    """
    # Convert the Pillow image to bytes
    image_bytes = pil_to_bytes(image, format="PNG")

    # Map feature type
    feature_map = {
        "TEXT_DETECTION": vision.Feature.Type.TEXT_DETECTION,
        "DOCUMENT_TEXT_DETECTION": vision.Feature.Type.DOCUMENT_TEXT_DETECTION,
    }
    if feature_type not in feature_map:
        raise ValueError(
            f"Invalid feature type '{feature_type}'. Use 'TEXT_DETECTION' or 'DOCUMENT_TEXT_DETECTION'."
        )

    # Prepare Vision API request
    vision_image = vision.Image(content=image_bytes)
    features = [vision.Feature(type=feature_map[feature_type])]
    image_context = vision.ImageContext(language_hints=language_hints)

    # Make the API call
    response = client.annotate_image(
        {"image": vision_image, "features": features, "image_context": image_context}
    )

    return response
save_processed_pdf_data(output_dir, journal_name, text_pages, word_locations, annotated_images, unannotated_images)

Saves processed PDF data to files for later reloading.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `output_dir` | `Path` | Directory to save the data (as a Path object). | *required* |
| `journal_name` | `str` | Name of the source PDF/journal, used as the output subdirectory name. | *required* |
| `text_pages` | `List[str]` | Extracted full-page text. | *required* |
| `word_locations` | `List[List[EntityAnnotation]]` | Word locations and annotations from the Vision API. | *required* |
| `annotated_images` | `List[Image.Image]` | Annotated images with bounding boxes. | *required* |
| `unannotated_images` | `List[Image.Image]` | Raw unannotated images. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `None` | Nothing is returned; data is written to disk. |

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def save_processed_pdf_data(
    output_dir: Path,
    journal_name: str,
    text_pages: List[str],
    word_locations: List[List[EntityAnnotation]],
    annotated_images: List[Image.Image],
    unannotated_images: List[Image.Image],
):
    """
    Saves processed PDF data to files for later reloading.

    Parameters:
        output_dir (Path): Directory to save the data (as a Path object).
        journal_name (str): Name used for the output subdirectory (usually the PDF name without extension).
        text_pages (List[str]): Extracted full-page text.
        word_locations (List[List[EntityAnnotation]]): Word locations and annotations from Vision API.
        annotated_images (List[PIL.Image.Image]): Annotated images with bounding boxes.
        unannotated_images (List[PIL.Image.Image]): Raw unannotated images.

    Returns:
        None
    """
    # Create output directories
    base_path = output_dir / journal_name / "ocr_data"
    images_dir = base_path / "images"

    base_path.mkdir(parents=True, exist_ok=True)
    images_dir.mkdir(parents=True, exist_ok=True)

    # Save text data
    text_pages_file = base_path / "text_pages.json"
    with text_pages_file.open("w", encoding="utf-8") as f:
        json.dump(text_pages, f, indent=4, ensure_ascii=False)

    # Save word locations as JSON
    word_locations_file = base_path / "word_locations.json"
    serialized_word_locations = serialize_entity_annotations_to_json(word_locations)
    with word_locations_file.open("w", encoding="utf-8") as f:
        f.write(serialized_word_locations)

    # Save images
    for i, annotated_image in enumerate(annotated_images):
        annotated_file = images_dir / f"annotated_page_{i + 1}.png"
        annotated_image.save(annotated_file)
    for i, unannotated_image in enumerate(unannotated_images):
        unannotated_file = images_dir / f"unannotated_page_{i + 1}.png"
        unannotated_image.save(unannotated_file)

    # Save metadata
    metadata = {
        "source_pdf": journal_name,
        "page_count": len(text_pages),
        "images_directory": str(
            images_dir
        ),  # Convert Path to string for JSON serialization
        "files": {
            "text_pages": "text_pages.json",
            "word_locations": "word_locations.json",
        },
    }
    metadata_file = base_path / "metadata.json"
    with metadata_file.open("w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=4)

    print(f"Processed data saved in: {base_path}")
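For reference, the on-disk layout this function produces looks like the following (shown here with `journal_name` set to `example`; page counts are illustrative):

```
output_dir/example/ocr_data/
├── metadata.json
├── text_pages.json
├── word_locations.json
└── images/
    ├── annotated_page_1.png
    ├── unannotated_page_1.png
    └── ...
```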
serialize_entity_annotations_to_json(annotations)

Serializes a nested list of EntityAnnotation objects into a JSON-compatible format using Base64 encoding.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `annotations` | `List[List[EntityAnnotation]]` | The nested list of `EntityAnnotation` objects. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `str` | The serialized data in JSON format as a string. |

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def serialize_entity_annotations_to_json(
    annotations: List[List[EntityAnnotation]],
) -> str:
    """
    Serializes a nested list of EntityAnnotation objects into a JSON-compatible format using Base64 encoding.

    Parameters:
        annotations (List[List[EntityAnnotation]]): The nested list of EntityAnnotation objects.

    Returns:
        str: The serialized data in JSON format as a string.
    """
    serialized_data = []
    for page_annotations in annotations:
        serialized_page = [
            base64.b64encode(annotation.SerializeToString()).decode("utf-8")
            for annotation in page_annotations
        ]
        serialized_data.append(serialized_page)

    # Convert to a JSON string
    return json.dumps(serialized_data, indent=4)
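The encode/decode round trip behind this serialization (and its `deserialize_entity_annotations_from_json` counterpart) can be sketched with plain stdlib tools; the byte strings below are stand-ins for serialized `EntityAnnotation` protobuf messages:

```python
import base64
import json

# Stand-ins for EntityAnnotation.SerializeToString() output (raw bytes).
pages = [[b"\x0a\x05hello", b"\x0a\x05world"], [b"\x0a\x03foo"]]

# Serialize: Base64-encode each message so the bytes survive JSON.
serialized = json.dumps(
    [[base64.b64encode(msg).decode("utf-8") for msg in page] for page in pages]
)

# Deserialize: reverse the steps to recover the original bytes exactly.
restored = [
    [base64.b64decode(item) for item in page] for page in json.loads(serialized)
]
print(restored == pages)  # True
```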
start_image_annotator_client(credentials_file=None, api_endpoint='vision.googleapis.com', timeout=(10, 30), enable_logging=False)

Starts and returns a Google Vision API ImageAnnotatorClient with optional configuration.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `credentials_file` | `Optional[str]` | Path to the credentials JSON file. If None, uses the default environment variable. | `None` |
| `api_endpoint` | `str` | Custom API endpoint for the Vision API. Default is the global endpoint. | `'vision.googleapis.com'` |
| `timeout` | `Tuple[int, int]` | Connection and read timeouts in seconds. | `(10, 30)` |
| `enable_logging` | `bool` | Enable detailed logging for debugging. | `False` |

Returns:

| Type | Description |
|------|-------------|
| `vision.ImageAnnotatorClient` | Configured Vision API client. |

Raises:

| Type | Description |
|------|-------------|
| `FileNotFoundError` | If the specified credentials file is not found. |
| `Exception` | For unexpected errors during client setup. |

Example:

```python
client = start_image_annotator_client(
    credentials_file="/path/to/credentials.json",
    api_endpoint="vision.googleapis.com",
    timeout=(10, 30),
    enable_logging=True,
)
print("Google Vision API client initialized.")
```

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def start_image_annotator_client(
    credentials_file: Optional[str] = None,
    api_endpoint: str = "vision.googleapis.com",
    timeout: Tuple[int, int] = (10, 30),
    enable_logging: bool = False,
) -> vision.ImageAnnotatorClient:
    """
    Starts and returns a Google Vision API ImageAnnotatorClient with optional configuration.

    Parameters:
        credentials_file (Optional[str]): Path to the credentials JSON file. If None, uses the default environment variable.
        api_endpoint (str): Custom API endpoint for the Vision API. Default is the global endpoint.
        timeout (Tuple[int, int]): Connection and read timeouts in seconds. Default is (10, 30).
        enable_logging (bool): Enable detailed logging for debugging. Default is False.

    Returns:
        vision.ImageAnnotatorClient: Configured Vision API client.

    Raises:
        FileNotFoundError: If the specified credentials file is not found.
        Exception: For unexpected errors during client setup.

    Example:
        >>> client = start_image_annotator_client(
        >>>     credentials_file="/path/to/credentials.json",
        >>>     api_endpoint="vision.googleapis.com",
        >>>     timeout=(10, 30),
        >>>     enable_logging=True
        >>> )
        >>> print("Google Vision API client initialized.")
    """
    try:
        # Set up credentials
        if credentials_file:
            if not os.path.exists(credentials_file):
                raise FileNotFoundError(
                    f"Credentials file '{credentials_file}' not found."
                )
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_file

        # Configure client options
        client_options = {"api_endpoint": api_endpoint}
        client = vision.ImageAnnotatorClient(client_options=client_options)

        # Optionally enable logging
        if enable_logging:
            print(f"Vision API Client started with endpoint: {api_endpoint}")
            print(f"Timeout settings: Connect={timeout[0]}s, Read={timeout[1]}s")

        return client

    except Exception as e:
        raise Exception(f"Failed to initialize ImageAnnotatorClient: {e}")

text_processing

__all__ = ['bracket_lines', 'unbracket_lines', 'lines_from_bracketed_text', 'NumberedText', 'normalize_newlines', 'clean_text'] module-attribute

NumberedText

Represents a text document with numbered lines for easy reference and manipulation.

Provides utilities for working with line-numbered text including reading, writing, accessing lines by number, and iterating over numbered lines.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `lines` | `List[str]` | List of text lines |
| `start` | `int` | Starting line number (default: 1) |
| `separator` | `str` | Separator between line number and content (default: `": "`) |

Examples:

>>> text = "First line\nSecond line\n\nFourth line"
>>> doc = NumberedText(text)
>>> print(doc)
1: First line
2: Second line
3:
4: Fourth line
>>> print(doc.get_line(2))
Second line
>>> for num, line in doc:
...     print(f"Line {num}: {len(line)} chars")
Source code in src/tnh_scholar/text_processing/numbered_text.py
class NumberedText:
    """
    Represents a text document with numbered lines for easy reference and manipulation.

    Provides utilities for working with line-numbered text including reading,
    writing, accessing lines by number, and iterating over numbered lines.

    Attributes:
        lines (List[str]): List of text lines
        start (int): Starting line number (default: 1)
        separator (str): Separator between line number and content (default: ": ")

    Examples:
        >>> text = "First line\\nSecond line\\n\\nFourth line"
        >>> doc = NumberedText(text)
        >>> print(doc)
        1: First line
        2: Second line
        3:
        4: Fourth line

        >>> print(doc.get_line(2))
        Second line

        >>> for num, line in doc:
        ...     print(f"Line {num}: {len(line)} chars")
    """

    @dataclass
    class LineSegment:
        """
        Represents a segment of lines with start and end indices in 1-based indexing.

        The segment follows Python range conventions where start is inclusive and
        end is exclusive. However, indexing is 1-based to match NumberedText.

        Attributes:
            start: Starting line number (inclusive, 1-based)
            end: Ending line number (exclusive, 1-based)
        """

        start: int
        end: int

        def __iter__(self):
            """Allow unpacking into start, end pairs."""
            yield self.start
            yield self.end

    class SegmentIterator:
        """
        Iterator for generating line segments of specified size.

        Produces segments of lines with start/end indices following 1-based indexing.
        The final segment may be smaller than the specified segment size.

        Attributes:
            total_lines: Total number of lines in text
            segment_size: Number of lines per segment
            start_line: Starting line number (1-based)
            min_segment_size: Minimum size for the final segment
        """

        def __init__(
            self,
            total_lines: int,
            segment_size: int,
            start_line: int = 1,
            min_segment_size: Optional[int] = None,
        ):
            """
            Initialize the segment iterator.

            Args:
                total_lines: Total number of lines to iterate over
                segment_size: Desired size of each segment
                start_line: First line number (default: 1)
                min_segment_size: Minimum size for final segment (default: None)
                    If specified, the last segment will be merged with the previous one
                    if it would be smaller than this size.

            Raises:
                ValueError: If segment_size < 1 or total_lines < 1
                ValueError: If start_line < 1 (must use 1-based indexing)
                ValueError: If min_segment_size >= segment_size
            """
            if segment_size < 1:
                raise ValueError("Segment size must be at least 1")
            if total_lines < 1:
                raise ValueError("Total lines must be at least 1")
            if start_line < 1:
                raise ValueError("Start line must be at least 1 (1-based indexing)")
            if min_segment_size is not None and min_segment_size >= segment_size:
                raise ValueError("Minimum segment size must be less than segment size")

            self.total_lines = total_lines
            self.segment_size = segment_size
            self.start_line = start_line
            self.min_segment_size = min_segment_size

            # Calculate number of segments
            remaining_lines = total_lines - start_line + 1
            self.num_segments = (remaining_lines + segment_size - 1) // segment_size

        def __iter__(self) -> Iterator["NumberedText.LineSegment"]:
            """
            Iterate over line segments.

            Yields:
                LineSegment containing start (inclusive) and end (exclusive) indices
            """
            current = self.start_line

            for _ in range(self.num_segments):
                segment_end = min(current + self.segment_size, self.total_lines + 1)

                # If the remaining tail after this segment would be shorter
                # than min_segment_size, absorb it into this segment so the
                # final lines are merged rather than dropped.
                if (
                    self.min_segment_size is not None
                    and segment_end <= self.total_lines
                    and self.total_lines + 1 - segment_end < self.min_segment_size
                ):
                    segment_end = self.total_lines + 1

                yield NumberedText.LineSegment(current, segment_end)
                if segment_end > self.total_lines:
                    break
                current = segment_end

    def __init__(
        self, content: Optional[str] = None, start: int = 1, separator: str = ":"
    ) -> None:
        """
        Initialize a numbered text document, 
        detecting and preserving existing numbering.

        Valid numbered text must have:
        - Sequential line numbers
        - Consistent separator character(s)
        - Every non-empty line must follow the numbering pattern

        Args:
            content: Initial text content, if any
            start: Starting line number (used only if content isn't already numbered)
            separator: Separator between line numbers and content 
            (only if content isn't numbered)

        Examples:
            >>> # Custom separators
            >>> doc = NumberedText("1→First line\\n2→Second line")
            >>> doc.separator == "→"
            True

            >>> # Preserves starting number
            >>> doc = NumberedText("5#First\\n6#Second")
            >>> doc.start == 5
            True

            >>> # Regular numbered list isn't treated as line numbers
            >>> doc = NumberedText("1. First item\\n2. Second item")
            >>> doc.numbered_lines
            ['1: 1. First item', '2: 2. Second item']
        """

        self.lines: List[str] = []
        self.start: int = start
        self.separator: str = separator

        if content is not None and not isinstance(content, str):
            raise ValueError("NumberedText requires string input.")

        if start < 1:  # enforce 1-based indexing
            raise IndexError(
                "NumberedText: Numbered lines must begin on "
                "an integer greater than or equal to 1."
            )

        if not content:
            return

        # Analyze the text format
        format_info = get_numbered_format(content)

        if format_info.is_numbered:
            self.start = format_info.start_num  # type: ignore
            self.separator = format_info.separator  # type: ignore

            # Extract content by removing number and separator
            pattern = re.compile(rf"^\d+{re.escape(format_info.separator)}")  # type: ignore
            self.lines = []

            for line in content.splitlines():
                if line.strip():
                    self.lines.append(pattern.sub("", line))
                else:
                    self.lines.append(line)
        else:
            self.lines = content.splitlines()
            self.start = start
            self.separator = separator

    @classmethod
    def from_file(cls, path: Path, **kwargs) -> "NumberedText":
        """Create a NumberedText instance from a file."""
        return cls(read_str_from_file(Path(path)), **kwargs)

    def _format_line(self, line_num: int, line: str) -> str:
        return f"{line_num}{self.separator}{line}"

    def _to_internal_index(self, idx: int) -> int:
        """Return the 0-based index into self.lines for a 1-based line number."""
        if idx > 0:
            return idx - self.start
        elif idx < 0:  # allow negative indexing from the end
            if abs(idx) > self.size:
                raise IndexError(f"NumberedText: negative index out of range: {idx}")
            # Convert to the equivalent positive line number, then to 0-based.
            return self.end + idx + 1 - self.start
        else:
            raise IndexError("NumberedText: Index cannot be zero in 1-based indexing.")

    def __str__(self) -> str:
        """Return the numbered text representation."""
        return "\n".join(
            self._format_line(i, line) for i, line in enumerate(self.lines, self.start)
        )

    def __len__(self) -> int:
        """Return the number of lines."""
        return len(self.lines)

    def __iter__(self) -> Iterator[tuple[int, str]]:
        """Iterate over (line_number, line_content) pairs."""
        return iter((i, line) for i, line in enumerate(self.lines, self.start))

    def __getitem__(self, index: int) -> str:
        """Get line content by line number (1-based indexing)."""
        return self.lines[self._to_internal_index(index)]

    def get_line(self, line_num: int) -> str:
        """Get content of specified line number."""
        return self[line_num]

    def _to_line_index(self, line_num: int) -> int:
        """Normalize a (possibly negative) line number to a positive 1-based line number."""
        return self.start + self._to_internal_index(line_num)

    def get_numbered_line(self, line_num: int) -> str:
        """Get specified line with line number."""
        idx = self._to_line_index(line_num)
        return self._format_line(idx, self[idx])

    def get_lines(self, start: int, end: int) -> List[str]:
        """Get content of line range, not inclusive of end line."""
        return self.lines[self._to_internal_index(start) : self._to_internal_index(end)]

    def get_numbered_lines(self, start: int, end: int) -> List[str]:
        """Get lines in [start, end) with line numbers prefixed."""
        return [
            self._format_line(start + i, line)
            for i, line in enumerate(self.get_lines(start, end))
        ]

    def get_segment(self, start: int, end: int) -> str:
        """return the segment from start line (inclusive) up to end line (exclusive)"""
        if start < self.start:
            raise IndexError(f"Start index {start} is before first line {self.start}")
        if end > self.end + 1:
            raise IndexError(f"End index {end} is past last line {self.end}")
        if start >= end:
            raise IndexError(f"Start index {start} must be less than end index {end}")
        return "\n".join(self.get_lines(start, end))

    def iter_segments(
        self, segment_size: int, min_segment_size: Optional[int] = None
    ) -> Iterator[LineSegment]:
        """
        Iterate over segments of the text with specified size.

        Args:
            segment_size: Number of lines per segment
            min_segment_size: Optional minimum size for final segment.
                If specified, last segment will be merged with previous one
                if it would be smaller than this size.

        Yields:
            LineSegment objects containing start and end line numbers

        Example:
            >>> text = NumberedText("line1\\nline2\\nline3\\nline4\\nline5")
            >>> for segment in text.iter_segments(2):
            ...     print(f"Lines {segment.start}-{segment.end}")
            Lines 1-3
            Lines 3-5
            Lines 5-6
        """
        iterator = self.SegmentIterator(
            len(self), segment_size, self.start, min_segment_size
        )
        return iter(iterator)

    def get_numbered_segment(self, start: int, end: int) -> str:
        """Return lines [start, end) joined with newlines, with number prefixes."""
        return "\n".join(self.get_numbered_lines(start, end))

    def save(self, path: Path, numbered: bool = True) -> None:
        """
        Save document to file.

        Args:
            path: Output file path
            numbered: Whether to save with line numbers (default: True)
        """
        content = str(self) if numbered else "\n".join(self.lines)
        write_str_to_file(path, content)

    def append(self, text: str) -> None:
        """Append text, splitting into lines if needed."""
        self.lines.extend(text.splitlines())

    def insert(self, line_num: int, text: str) -> None:
        """Insert text at specified line number. Assumes text is not empty."""
        new_lines = text.splitlines()
        internal_idx = self._to_internal_index(line_num)
        self.lines[internal_idx:internal_idx] = new_lines

    def reset_numbering(self) -> None:
        """Reset the starting line number to 1."""
        self.start = 1

    def remove_whitespace(self) -> None:
        """Remove leading and trailing whitespace from all lines."""
        self.lines = [line.strip() for line in self.lines]

    @property
    def content(self) -> str:
        """Get original text without line numbers."""
        return "\n".join(self.lines)

    @property
    def numbered_content(self) -> str:
        """Get text with line numbers as a string. Equivalent to str(self)"""
        return str(self)

    @property
    def size(self) -> int:
        """Get the number of lines."""
        return len(self.lines)

    @property
    def numbered_lines(self) -> List[str]:
        """
        Get list of lines with line numbers included.

        Returns:
            List[str]: Lines with numbers and separator prefixed

        Examples:
            >>> doc = NumberedText("First line\\nSecond line")
            >>> doc.numbered_lines
            ['1: First line', '2: Second line']

        Note:
            - Unlike str(self), this returns a list rather than joined string
            - Maintains consistent formatting with separator
            - Useful for processing or displaying individual numbered lines
        """
        return [
            f"{i}{self.separator}{line}"
            for i, line in enumerate(self.lines, self.start)
        ]

    @property
    def end(self) -> int:
        return self.start + len(self.lines) - 1
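The segmentation arithmetic used by `SegmentIterator` (ceiling-division segment count, 1-based end-exclusive bounds, short final tail merged into its predecessor) can be sketched standalone. This is an illustrative sketch, not part of the library; the function name `line_segments` is chosen here for the example:

```python
from typing import Iterator, Optional, Tuple

def line_segments(
    total_lines: int,
    segment_size: int,
    start_line: int = 1,
    min_segment_size: Optional[int] = None,
) -> Iterator[Tuple[int, int]]:
    """Yield 1-based, end-exclusive (start, end) pairs over the line range."""
    remaining = total_lines - start_line + 1
    num_segments = (remaining + segment_size - 1) // segment_size  # ceiling division
    current = start_line
    for _ in range(num_segments):
        end = min(current + segment_size, total_lines + 1)
        # Absorb a too-short final tail into this segment.
        if (
            min_segment_size is not None
            and end <= total_lines
            and total_lines + 1 - end < min_segment_size
        ):
            end = total_lines + 1
        yield current, end
        if end > total_lines:
            break
        current = end

# Five lines in segments of two: (1, 3), (3, 5), (5, 6)
print(list(line_segments(5, 2)))
# With min_segment_size=2 the one-line tail merges: (1, 3), (3, 6)
print(list(line_segments(5, 2, min_segment_size=2)))
```

This mirrors the behavior documented for `iter_segments`: segments follow Python range conventions but stay 1-based to match `NumberedText` line numbers.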

bracket_lines(text, number=False)

Encloses each line of the input text with angle brackets.
If number is True, adds a line number followed by a colon `:` and then the line.

Args:
    text (str): The input string containing lines separated by '\n'.
    number (bool): Whether to prepend line numbers to each line.

Returns:
    str: A string where each line is enclosed in angle brackets.

Examples:
    >>> bracket_lines("This is a string with\n   two lines.")
    '<This is a string with>\n<   two lines.>'

    >>> bracket_lines("This is a string with\n   two lines.", number=True)
    '<1:This is a string with>\n<2:   two lines.>'

Source code in src/tnh_scholar/text_processing/bracket.py
def bracket_lines(text: str, number: bool = False) -> str:
    """
    Encloses each line of the input text with angle brackets.
    If number is True, adds a line number followed by a colon `:` and then the line.

    Args:
        text (str): The input string containing lines separated by '\n'.
        number (bool): Whether to prepend line numbers to each line.

    Returns:
        str: A string where each line is enclosed in angle brackets.

    Examples:
        >>> bracket_lines("This is a string with\n   two lines.")
        '<This is a string with>\n<   two lines.>'

        >>> bracket_lines("This is a string with\n   two lines.", number=True)
        '<1:This is a string with>\n<2:   two lines.>'
    """
    return "\n".join(
        f"<{f'{i+1}:{line}' if number else line}>"
        for i, line in enumerate(text.split("\n"))
    )

clean_text(text, newline=False)

Cleans a given text by replacing specific unwanted characters, such as tabs and non-breaking spaces, with regular spaces.

This function takes a string as input and applies replacements based on a predefined mapping of characters to replace.

Parameters:

    text (str): The text to be cleaned. [required]
    newline (bool): If True, also remove newline characters. Defaults to False.

Returns:

    str: The cleaned text with unwanted characters replaced by spaces.

Example

>>> text = "This is\n an example\ttext with\xa0extra spaces."
>>> clean_text(text, newline=True)
'This is an example text with extra spaces.'

Source code in src/tnh_scholar/text_processing/text_processing.py
def clean_text(text, newline=False):
    """
    Cleans a given text by replacing specific unwanted characters, such as
    tabs and non-breaking spaces, with regular spaces.

    This function takes a string as input and applies replacements
    based on a predefined mapping of characters to replace.

    Args:
        text (str): The text to be cleaned.
        newline (bool): If True, also remove newline characters. Defaults to False.

    Returns:
        str: The cleaned text with unwanted characters replaced by spaces.

    Example:
        >>> text = "This is\\n an example\\ttext with\\xa0extra spaces."
        >>> clean_text(text, newline=True)
        'This is an example text with extra spaces.'

    """
    # Define a mapping of characters to replace
    replace_map = {
        "\t": " ",  # Replace tabs with space
        "\xa0": " ",  # Replace non-breaking space with regular space
        # Add more replacements as needed
    }

    if newline:
        replace_map["\n"] = ""  # remove newlines

    # Loop through the replace map and replace each character
    for old_char, new_char in replace_map.items():
        text = text.replace(old_char, new_char)

    return text.strip()  # Ensure any leading/trailing spaces are removed
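The character-by-character replacement loop above can equivalently be expressed with a single `str.translate` call. The snippet below is an illustrative alternative using the same mapping, not part of the library API:

```python
# str.maketrans builds a translation table from the same mapping used in
# clean_text's replace_map; translate applies every replacement in one pass.
replace_map = {"\t": " ", "\xa0": " "}
table = str.maketrans(replace_map)

text = "This is\tan example\xa0text."
print(text.translate(table).strip())  # This is an example text.
```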

lines_from_bracketed_text(text, start, end, keep_brackets=False)

Extracts lines from bracketed text between the start and end indices, inclusive.
Handles both numbered and non-numbered cases.

Args:
    text (str): The input bracketed text containing lines like <...>.
    start (int): The starting line number (1-based).
    end (int): The ending line number (1-based).
    keep_brackets (bool): If True, return lines with their brackets intact. Defaults to False.

Returns:
    list[str]: The lines from start to end inclusive, with angle brackets removed.

Raises:
    FormattingError: If the text contains improperly formatted lines (missing angle brackets).
    ValueError: If start or end indices are invalid or out of bounds.

Examples:
    >>> text = "<1:Line 1>\n<2:Line 2>\n<3:Line 3>"
    >>> lines_from_bracketed_text(text, 1, 2)
    ['Line 1', 'Line 2']

    >>> text = "<Line 1>\n<Line 2>\n<Line 3>"
    >>> lines_from_bracketed_text(text, 2, 3)
    ['Line 2', 'Line 3']

Source code in src/tnh_scholar/text_processing/bracket.py
def lines_from_bracketed_text(
    text: str, start: int, end: int, keep_brackets=False
) -> list[str]:
    """
    Extracts lines from bracketed text between the start and end indices, inclusive.
    Handles both numbered and non-numbered cases.

    Args:
        text (str): The input bracketed text containing lines like <...>.
        start (int): The starting line number (1-based).
        end (int): The ending line number (1-based).
        keep_brackets (bool): If True, return lines with their brackets intact. Defaults to False.

    Returns:
        list[str]: The lines from start to end inclusive, with angle brackets removed.

    Raises:
        FormattingError: If the text contains improperly formatted lines (missing angle brackets).
        ValueError: If start or end indices are invalid or out of bounds.

    Examples:
        >>> text = "<1:Line 1>\n<2:Line 2>\n<3:Line 3>"
        >>> lines_from_bracketed_text(text, 1, 2)
        ['Line 1', 'Line 2']

        >>> text = "<Line 1>\n<Line 2>\n<Line 3>"
        >>> lines_from_bracketed_text(text, 2, 3)
        ['Line 2', 'Line 3']
    """
    # Split the text into lines
    lines = text.splitlines()

    # Validate indices
    if start < 1 or end < 1 or start > end or end > len(lines):
        raise ValueError(
            f"Invalid start or end indices for the given text: start: {start}, end: {end}"
        )

    # Extract lines and validate formatting
    result = []
    for i, line in enumerate(lines, start=1):
        if start <= i <= end:
            # Check for proper bracketing and extract the content
            match = re.match(r"<(\d+:)?(.*?)>", line)
            if not match:
                raise FormattingError(f"Invalid format for line {i}: '{line}'")
            # Add the extracted content (group 2) to the result
            if keep_brackets:
                result.append(line)
            else:
                result.append(match[2].strip())

    return result  # matches the list[str] annotation and docstring examples

normalize_newlines(text, spacing=2)

Normalize newline blocks in the input text by reducing consecutive newlines
to the specified number of newlines for consistent readability and formatting.

Parameters:
----------
text : str
    The input text containing inconsistent newline spacing.
spacing : int, optional
    The number of newlines to insert between lines. Defaults to 2.

Returns:
-------
str
    The text with consecutive newlines reduced to the specified number of newlines.

Example:
--------
>>> raw_text = "Heading\n\n\nParagraph text 1\nParagraph text 2\n\n\n"
>>> normalize_newlines(raw_text, spacing=2)
'Heading\n\nParagraph text 1\n\nParagraph text 2\n\n'

Source code in src/tnh_scholar/text_processing/text_processing.py
def normalize_newlines(text: str, spacing: int = 2) -> str:
    """
    Normalize newline blocks in the input text by reducing consecutive newlines
    to the specified number of newlines for consistent readability and formatting.

    Parameters:
    ----------
    text : str
        The input text containing inconsistent newline spacing.
    spacing : int, optional
        The number of newlines to insert between lines. Defaults to 2.

    Returns:
    -------
    str
        The text with consecutive newlines reduced to the specified number of newlines.

    Example:
    --------
    >>> raw_text = "Heading\n\n\nParagraph text 1\nParagraph text 2\n\n\n"
    >>> normalize_newlines(raw_text, spacing=2)
    'Heading\n\nParagraph text 1\n\nParagraph text 2\n\n'
    """
    # Replace one or more newlines with the desired number of newlines
    newlines = "\n" * spacing
    return re.sub(r"\n{1,}", newlines, text)
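Note that the pattern `\n{1,}` matches single newlines as well as runs, so every newline boundary, not just multi-newline blocks, is normalized to exactly `spacing` newlines. A standalone reproduction of the function body illustrates this:

```python
import re

def normalize_newlines(text: str, spacing: int = 2) -> str:
    # Same body as above: every run of one or more newlines
    # becomes exactly `spacing` newlines.
    return re.sub(r"\n{1,}", "\n" * spacing, text)

# A single newline is expanded, and a triple newline is collapsed:
print(repr(normalize_newlines("a\nb\n\n\nc")))  # 'a\n\nb\n\nc'
```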

unbracket_lines(text, number=False)

Removes angle brackets (< >) from encapsulated lines and optionally removes line numbers.

Args:
    text (str): The input string with encapsulated lines.
    number (bool): If True, removes line numbers in the format 'digit:'.
                   Raises a FormattingError if `number=True` and a line does not start with a digit followed by a colon.

Returns:
    str: A newline-separated string with the encapsulation removed, and line numbers stripped if specified.

Examples:
    >>> unbracket_lines("<1:Line 1>\n<2:Line 2>", number=True)
    'Line 1\nLine 2'

    >>> unbracket_lines("<Line 1>\n<Line 2>")
    'Line 1\nLine 2'

    >>> unbracket_lines("<1Line 1>", number=True)
    FormattingError: Line does not start with a valid number: '1Line 1'
Source code in src/tnh_scholar/text_processing/bracket.py
def unbracket_lines(text: str, number: bool = False) -> str:
    """
    Removes angle brackets (< >) from encapsulated lines and optionally removes line numbers.

    Args:
        text (str): The input string with encapsulated lines.
        number (bool): If True, removes line numbers in the format 'digit:'.
                       Raises a FormattingError if `number=True` and a line does not start with a digit followed by a colon.

    Returns:
        str: A newline-separated string with the encapsulation removed, and line numbers stripped if specified.

    Examples:
        >>> unbracket_lines("<1:Line 1>\n<2:Line 2>", number=True)
        'Line 1\nLine 2'

        >>> unbracket_lines("<Line 1>\n<Line 2>")
        'Line 1\nLine 2'

        >>> unbracket_lines("<1Line 1>", number=True)
        FormattingError: Line does not start with a valid number: '1Line 1'
    """
    unbracketed_lines = []

    for line in text.splitlines():
        match = (
            re.match(r"<(\d+):(.*?)>", line) if number else re.match(r"<(.*?)>", line)
        )
        if match:
            content = match[2].strip() if number else match[1].strip()
            unbracketed_lines.append(content)
        elif number:
            raise FormattingError(f"Line does not start with a valid number: '{line}'")
        else:
            raise FormattingError(f"Line does not follow the expected format: '{line}'")

    return "\n".join(unbracketed_lines)
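`bracket_lines` and `unbracket_lines` are designed to round-trip (modulo the `strip()` applied on removal, which drops leading and trailing whitespace inside each bracket). A minimal standalone reproduction of the two bodies demonstrates the round trip:

```python
import re

# Minimal reproductions of bracket_lines / unbracket_lines from the
# source above, shown to illustrate that the two functions round-trip.
def bracket_lines(text, number=False):
    return "\n".join(
        f"<{f'{i+1}:{line}' if number else line}>"
        for i, line in enumerate(text.split("\n"))
    )

def unbracket_lines(text, number=False):
    out = []
    for line in text.splitlines():
        m = re.match(r"<(\d+):(.*?)>", line) if number else re.match(r"<(.*?)>", line)
        if not m:
            raise ValueError(f"Unexpected format: {line!r}")
        out.append((m[2] if number else m[1]).strip())
    return "\n".join(out)

original = "First line\nSecond line"
assert unbracket_lines(bracket_lines(original, number=True), number=True) == original
```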

bracket

FormattingError

Bases: Exception

Custom exception raised for formatting-related errors.

Source code in src/tnh_scholar/text_processing/bracket.py
class FormattingError(Exception):
    """
    Custom exception raised for formatting-related errors.
    """

    def __init__(self, message="An error occurred due to invalid formatting."):
        super().__init__(message)
__init__(message='An error occurred due to invalid formatting.')
Source code in src/tnh_scholar/text_processing/bracket.py
def __init__(self, message="An error occurred due to invalid formatting."):
    super().__init__(message)
bracket_all_lines(pages)
Source code in src/tnh_scholar/text_processing/bracket.py
def bracket_all_lines(pages):
    return [bracket_lines(page) for page in pages]
bracket_lines(text, number=False)
Encloses each line of the input text with angle brackets.
If number is True, adds a line number followed by a colon `:` and then the line.

Args:
    text (str): The input string containing lines separated by '\n'.
    number (bool): Whether to prepend line numbers to each line.

Returns:
    str: A string where each line is enclosed in angle brackets.

Examples:
    >>> bracket_lines("This is a string with\n   two lines.")
    '<This is a string with>\n<   two lines.>'

    >>> bracket_lines("This is a string with\n   two lines.", number=True)
    '<1:This is a string with>\n<2:   two lines.>'

Source code in src/tnh_scholar/text_processing/bracket.py
def bracket_lines(text: str, number: bool = False) -> str:
    """
    Encloses each line of the input text with angle brackets.
    If number is True, adds a line number followed by a colon `:` and then the line.

    Args:
        text (str): The input string containing lines separated by '\n'.
        number (bool): Whether to prepend line numbers to each line.

    Returns:
        str: A string where each line is enclosed in angle brackets.

    Examples:
        >>> bracket_lines("This is a string with\n   two lines.")
        '<This is a string with>\n<   two lines.>'

        >>> bracket_lines("This is a string with\n   two lines.", number=True)
        '<1:This is a string with>\n<2:   two lines.>'
    """
    return "\n".join(
        f"<{f'{i+1}:{line}' if number else line}>"
        for i, line in enumerate(text.split("\n"))
    )
lines_from_bracketed_text(text, start, end, keep_brackets=False)
Extracts lines from bracketed text between the start and end indices, inclusive.
Handles both numbered and non-numbered cases.

Args:
    text (str): The input bracketed text containing lines like <...>.
    start (int): The starting line number (1-based).
    end (int): The ending line number (1-based).
    keep_brackets (bool): If True, return lines with their brackets intact. Defaults to False.

Returns:
    list[str]: The lines from start to end inclusive, with angle brackets removed.

Raises:
    FormattingError: If the text contains improperly formatted lines (missing angle brackets).
    ValueError: If start or end indices are invalid or out of bounds.

Examples:
    >>> text = "<1:Line 1>\n<2:Line 2>\n<3:Line 3>"
    >>> lines_from_bracketed_text(text, 1, 2)
    ['Line 1', 'Line 2']

    >>> text = "<Line 1>\n<Line 2>\n<Line 3>"
    >>> lines_from_bracketed_text(text, 2, 3)
    ['Line 2', 'Line 3']

Source code in src/tnh_scholar/text_processing/bracket.py
def lines_from_bracketed_text(
    text: str, start: int, end: int, keep_brackets=False
) -> list[str]:
    """
    Extracts lines from bracketed text between the start and end indices, inclusive.
    Handles both numbered and non-numbered cases.

    Args:
        text (str): The input bracketed text containing lines like <...>.
        start (int): The starting line number (1-based).
        end (int): The ending line number (1-based).
        keep_brackets (bool): If True, return lines with their brackets intact. Defaults to False.

    Returns:
        list[str]: The lines from start to end inclusive, with angle brackets removed.

    Raises:
        FormattingError: If the text contains improperly formatted lines (missing angle brackets).
        ValueError: If start or end indices are invalid or out of bounds.

    Examples:
        >>> text = "<1:Line 1>\n<2:Line 2>\n<3:Line 3>"
        >>> lines_from_bracketed_text(text, 1, 2)
        ['Line 1', 'Line 2']

        >>> text = "<Line 1>\n<Line 2>\n<Line 3>"
        >>> lines_from_bracketed_text(text, 2, 3)
        ['Line 2', 'Line 3']
    """
    # Split the text into lines
    lines = text.splitlines()

    # Validate indices
    if start < 1 or end < 1 or start > end or end > len(lines):
        raise ValueError(
            f"Invalid start or end indices for the given text: start: {start}, end: {end}"
        )

    # Extract lines and validate formatting
    result = []
    for i, line in enumerate(lines, start=1):
        if start <= i <= end:
            # Check for proper bracketing and extract the content
            match = re.match(r"<(\d+:)?(.*?)>", line)
            if not match:
                raise FormattingError(f"Invalid format for line {i}: '{line}'")
            # Add the extracted content (group 2) to the result
            if keep_brackets:
                result.append(line)
            else:
                result.append(match[2].strip())

    return result  # matches the list[str] annotation and docstring examples
number_lines(text, start=1, separator=': ')

Numbers each line of text with a readable format, including empty lines.

Parameters:

    text (str): Input text to be numbered. Can be multi-line. [required]
    start (int): Starting line number. Defaults to 1.
    separator (str): Separator between line number and content. Defaults to ": ".

Returns:

    str: Numbered text where each line starts with "{number}: ".

Examples:

>>> text = "First line\nSecond line\n\nFourth line"
>>> print(number_lines(text))
1: First line
2: Second line
3:
4: Fourth line
>>> print(number_lines(text, start=5, separator=" | "))
5 | First line
6 | Second line
7 |
8 | Fourth line
Notes
  • All lines are numbered, including empty lines, to maintain text structure
  • Line numbers are aligned through natural string formatting
  • Customizable separator allows for different formatting needs
  • Can start from any line number for flexibility in text processing
Source code in src/tnh_scholar/text_processing/bracket.py
def number_lines(text: str, start: int = 1, separator: str = ": ") -> str:
    """
    Numbers each line of text with a readable format, including empty lines.

    Args:
        text (str): Input text to be numbered. Can be multi-line.
        start (int, optional): Starting line number. Defaults to 1.
        separator (str, optional): Separator between line number and content.
            Defaults to ": ".

    Returns:
        str: Numbered text where each line starts with "{number}: ".

    Examples:
        >>> text = "First line\\nSecond line\\n\\nFourth line"
        >>> print(number_lines(text))
        1: First line
        2: Second line
        3:
        4: Fourth line

        >>> print(number_lines(text, start=5, separator=" | "))
        5 | First line
        6 | Second line
        7 |
        8 | Fourth line

    Notes:
        - All lines are numbered, including empty lines, to maintain text structure
        - Line numbers are aligned through natural string formatting
        - Customizable separator allows for different formatting needs
        - Can start from any line number for flexibility in text processing
    """
    lines = text.splitlines()
    return "\n".join(f"{i}{separator}{line}" for i, line in enumerate(lines, start))
unbracket_all_lines(pages)
Source code in src/tnh_scholar/text_processing/bracket.py
def unbracket_all_lines(pages):
    result = []
    for page in pages:
        if page == "blank page":
            result.append(page)
        else:
            result.append(unbracket_lines(page))
    return result
unbracket_lines(text, number=False)
Removes angle brackets (< >) from encapsulated lines and optionally removes line numbers.

Args:
    text (str): The input string with encapsulated lines.
    number (bool): If True, removes line numbers in the format 'digit:'.
                   Raises a FormattingError if `number=True` and a line does not start with a digit followed by a colon.

Returns:
    str: A newline-separated string with the encapsulation removed, and line numbers stripped if specified.

Examples:
    >>> unbracket_lines("<1:Line 1>\n<2:Line 2>", number=True)
    'Line 1\nLine 2'

    >>> unbracket_lines("<Line 1>\n<Line 2>")
    'Line 1\nLine 2'

    >>> unbracket_lines("<1Line 1>", number=True)
    FormattingError: Line does not start with a valid number: '1Line 1'
Source code in src/tnh_scholar/text_processing/bracket.py
def unbracket_lines(text: str, number: bool = False) -> str:
    """
    Removes angle brackets (< >) from encapsulated lines and optionally removes line numbers.

    Args:
        text (str): The input string with encapsulated lines.
        number (bool): If True, removes line numbers in the format 'digit:'.
                       Raises a FormattingError if `number=True` and a line does not start with a digit followed by a colon.

    Returns:
        str: A newline-separated string with the encapsulation removed, and line numbers stripped if specified.

    Examples:
        >>> unbracket_lines("<1:Line 1>\n<2:Line 2>", number=True)
        'Line 1\nLine 2'

        >>> unbracket_lines("<Line 1>\n<Line 2>")
        'Line 1\nLine 2'

        >>> unbracket_lines("<1Line 1>", number=True)
        FormattingError: Line does not start with a valid number: '1Line 1'
    """
    unbracketed_lines = []

    for line in text.splitlines():
        match = (
            re.match(r"<(\d+):(.*?)>", line) if number else re.match(r"<(.*?)>", line)
        )
        if match:
            content = match[2].strip() if number else match[1].strip()
            unbracketed_lines.append(content)
        elif number:
            raise FormattingError(f"Line does not start with a valid number: '{line}'")
        else:
            raise FormattingError(f"Line does not follow the expected format: '{line}'")

    return "\n".join(unbracketed_lines)

numbered_text

NumberedFormat

Bases: NamedTuple

Source code in src/tnh_scholar/text_processing/numbered_text.py
class NumberedFormat(NamedTuple):
    is_numbered: bool
    separator: Optional[str] = None
    start_num: Optional[int] = None
is_numbered instance-attribute
separator = None class-attribute instance-attribute
start_num = None class-attribute instance-attribute
NumberedText

Represents a text document with numbered lines for easy reference and manipulation.

Provides utilities for working with line-numbered text including reading, writing, accessing lines by number, and iterating over numbered lines.

Attributes:

    lines (List[str]): List of text lines
    start (int): Starting line number (default: 1)
    separator (str): Separator between line number and content (default: ": ")

Examples:

>>> text = "First line\nSecond line\n\nFourth line"
>>> doc = NumberedText(text)
>>> print(doc)
1: First line
2: Second line
3:
4: Fourth line
>>> print(doc.get_line(2))
Second line
>>> for num, line in doc:
...     print(f"Line {num}: {len(line)} chars")
Source code in src/tnh_scholar/text_processing/numbered_text.py
class NumberedText:
    """
    Represents a text document with numbered lines for easy reference and manipulation.

    Provides utilities for working with line-numbered text including reading,
    writing, accessing lines by number, and iterating over numbered lines.

    Attributes:
        lines (List[str]): List of text lines
        start (int): Starting line number (default: 1)
        separator (str): Separator between line number and content (default: ": ")

    Examples:
        >>> text = "First line\\nSecond line\\n\\nFourth line"
        >>> doc = NumberedText(text)
        >>> print(doc)
        1: First line
        2: Second line
        3:
        4: Fourth line

        >>> print(doc.get_line(2))
        Second line

        >>> for num, line in doc:
        ...     print(f"Line {num}: {len(line)} chars")
    """

    @dataclass
    class LineSegment:
        """
        Represents a segment of lines with start and end indices in 1-based indexing.

        The segment follows Python range conventions where start is inclusive and
        end is exclusive. However, indexing is 1-based to match NumberedText.

        Attributes:
            start: Starting line number (inclusive, 1-based)
            end: Ending line number (exclusive, 1-based)
        """

        start: int
        end: int

        def __iter__(self):
            """Allow unpacking into start, end pairs."""
            yield self.start
            yield self.end

    class SegmentIterator:
        """
        Iterator for generating line segments of specified size.

        Produces segments of lines with start/end indices following 1-based indexing.
        The final segment may be smaller than the specified segment size.

        Attributes:
            total_lines: Total number of lines in text
            segment_size: Number of lines per segment
            start_line: Starting line number (1-based)
            min_segment_size: Minimum size for the final segment
        """

        def __init__(
            self,
            total_lines: int,
            segment_size: int,
            start_line: int = 1,
            min_segment_size: Optional[int] = None,
        ):
            """
            Initialize the segment iterator.

            Args:
                total_lines: Total number of lines to iterate over
                segment_size: Desired size of each segment
                start_line: First line number (default: 1)
                min_segment_size: Minimum size for final segment (default: None)
                    If specified, the last segment will be merged with the previous one
                    if it would be smaller than this size.

            Raises:
                ValueError: If segment_size < 1 or total_lines < 1
                ValueError: If start_line < 1 (must use 1-based indexing)
                ValueError: If min_segment_size >= segment_size
            """
            if segment_size < 1:
                raise ValueError("Segment size must be at least 1")
            if total_lines < 1:
                raise ValueError("Total lines must be at least 1")
            if start_line < 1:
                raise ValueError("Start line must be at least 1 (1-based indexing)")
            if min_segment_size is not None and min_segment_size >= segment_size:
                raise ValueError("Minimum segment size must be less than segment size")

            self.total_lines = total_lines
            self.segment_size = segment_size
            self.start_line = start_line
            self.min_segment_size = min_segment_size

            # Calculate number of segments
            remaining_lines = total_lines - start_line + 1
            self.num_segments = (remaining_lines + segment_size - 1) // segment_size

        def __iter__(self) -> Iterator["NumberedText.LineSegment"]:
            """
            Iterate over line segments.

            Yields:
                LineSegment containing start (inclusive) and end (exclusive) indices
            """
            current = self.start_line

            for i in range(self.num_segments):
                is_last_segment = i == self.num_segments - 1
                segment_end = min(current + self.segment_size, self.total_lines + 1)

                # Handle minimum segment size for last segment
                if (
                    is_last_segment
                    and self.min_segment_size is not None
                    and segment_end - current < self.min_segment_size
                    and i > 0
                ):
                    # Merge with previous segment by not yielding
                    break

                yield NumberedText.LineSegment(current, segment_end)
                current = segment_end

    def __init__(
        self, content: Optional[str] = None, start: int = 1, separator: str = ":"
    ) -> None:
        """
        Initialize a numbered text document, 
        detecting and preserving existing numbering.

        Valid numbered text must have:
        - Sequential line numbers
        - Consistent separator character(s)
        - Every non-empty line must follow the numbering pattern

        Args:
            content: Initial text content, if any
            start: Starting line number (used only if content isn't already numbered)
            separator: Separator between line numbers and content 
            (only if content isn't numbered)

        Examples:
            >>> # Custom separators
            >>> doc = NumberedText("1→First line\\n2→Second line")
            >>> doc.separator == "→"
            True

            >>> # Preserves starting number
            >>> doc = NumberedText("5#First\\n6#Second")
            >>> doc.start == 5
            True

            >>> # Regular numbered list isn't treated as line numbers
            >>> doc = NumberedText("1. First item\\n2. Second item")
            >>> doc.numbered_lines
            ['1: 1. First item', '2: 2. Second item']
        """

        self.lines: List[str] = []  # Declare lines here
        self.start: int = start  # Declare start with its type
        self.separator: str = separator  # and separator

        if content is None:
            content = ""
        if not isinstance(content, str):
            raise ValueError("NumberedText requires string input.")

        if start < 1:  # enforce 1-based indexing
            raise IndexError(
                "NumberedText: Numbered lines must begin on "
                "an integer greater than or equal to 1."
            )

        if not content:
            return

        # Analyze the text format
        format_info = get_numbered_format(content)

        if format_info.is_numbered:
            self.start = format_info.start_num  # type: ignore
            self.separator = format_info.separator  # type: ignore

            # Extract content by removing the number-and-separator prefix
            pattern = re.compile(rf"^\d+{re.escape(format_info.separator)}")  # type: ignore
            self.lines = []

            for line in content.splitlines():
                if line.strip():
                    self.lines.append(pattern.sub("", line))
                else:
                    self.lines.append(line)
        else:
            self.lines = content.splitlines()
            self.start = start
            self.separator = separator

    @classmethod
    def from_file(cls, path: Path, **kwargs) -> "NumberedText":
        """Create a NumberedText instance from a file."""
        return cls(read_str_from_file(Path(path)), **kwargs)

    def _format_line(self, line_num: int, line: str) -> str:
        return f"{line_num}{self.separator}{line}"

    def _to_internal_index(self, idx: int) -> int:
        """return the index into the lines object in Python 0-based indexing."""
        if idx > 0:
            return idx - self.start
        elif idx < 0:  # allow negative indexing to index from end
            if abs(idx) > self.size:
                raise IndexError(f"NumberedText: negative index out of range: {idx}")
            return self.end + idx  # convert to logical positive location for reference.
        else:
            raise IndexError("NumberedText: Index cannot be zero in 1-based indexing.")

    def __str__(self) -> str:
        """Return the numbered text representation."""
        return "\n".join(
            self._format_line(i, line) for i, line in enumerate(self.lines, self.start)
        )

    def __len__(self) -> int:
        """Return the number of lines."""
        return len(self.lines)

    def __iter__(self) -> Iterator[tuple[int, str]]:
        """Iterate over (line_number, line_content) pairs."""
        return iter((i, line) for i, line in enumerate(self.lines, self.start))

    def __getitem__(self, index: int) -> str:
        """Get line content by line number (1-based indexing)."""
        return self.lines[self._to_internal_index(index)]

    def get_line(self, line_num: int) -> str:
        """Get content of specified line number."""
        return self[line_num]

    def _to_line_index(self, internal_index: int) -> int:
        return self.start + self._to_internal_index(internal_index)

    def get_numbered_line(self, line_num: int) -> str:
        """Get specified line with line number."""
        idx = self._to_line_index(line_num)
        return self._format_line(idx, self[idx])

    def get_lines(self, start: int, end: int) -> List[str]:
        """Get content of line range, not inclusive of end line."""
        return self.lines[self._to_internal_index(start) : self._to_internal_index(end)]

    def get_numbered_lines(self, start: int, end: int) -> List[str]:
        """Get lines in range with line numbers, not inclusive of end line."""
        return [
            self._format_line(start + i, line)
            for i, line in enumerate(self.get_lines(start, end))
        ]

    def get_segment(self, start: int, end: int) -> str:
        """return the segment from start line (inclusive) up to end line (exclusive)"""
        if start < self.start:
            raise IndexError(f"Start index {start} is before first line {self.start}")
        if end > len(self) + 1:
            raise IndexError(f"End index {end} is past last line {len(self)}")
        if start >= end:
            raise IndexError(f"Start index {start} must be less than end index {end}")
        return "\n".join(self.get_lines(start, end))

    def iter_segments(
        self, segment_size: int, min_segment_size: Optional[int] = None
    ) -> Iterator[LineSegment]:
        """
        Iterate over segments of the text with specified size.

        Args:
            segment_size: Number of lines per segment
            min_segment_size: Optional minimum size for final segment.
                If specified, last segment will be merged with previous one
                if it would be smaller than this size.

        Yields:
            LineSegment objects containing start and end line numbers

        Example:
            >>> text = NumberedText("line1\\nline2\\nline3\\nline4\\nline5")
            >>> for segment in text.iter_segments(2):
            ...     print(f"Lines {segment.start}-{segment.end}")
            Lines 1-3
            Lines 3-5
            Lines 5-6
        """
        iterator = self.SegmentIterator(
            len(self), segment_size, self.start, min_segment_size
        )
        return iter(iterator)

    def get_numbered_segment(self, start: int, end: int) -> str:
        return "\n".join(self.get_numbered_lines(start, end))

    def save(self, path: Path, numbered: bool = True) -> None:
        """
        Save document to file.

        Args:
            path: Output file path
            numbered: Whether to save with line numbers (default: True)
        """
        content = str(self) if numbered else "\n".join(self.lines)
        write_str_to_file(path, content)

    def append(self, text: str) -> None:
        """Append text, splitting into lines if needed."""
        self.lines.extend(text.splitlines())

    def insert(self, line_num: int, text: str) -> None:
        """Insert text at specified line number. Assumes text is not empty."""
        new_lines = text.splitlines()
        internal_idx = self._to_internal_index(line_num)
        self.lines[internal_idx:internal_idx] = new_lines

    def reset_numbering(self):
        self.start = 1

    def remove_whitespace(self) -> None:
        """Remove leading and trailing whitespace from all lines."""
        self.lines = [line.strip() for line in self.lines]

    @property
    def content(self) -> str:
        """Get original text without line numbers."""
        return "\n".join(self.lines)

    @property
    def numbered_content(self) -> str:
        """Get text with line numbers as a string. Equivalent to str(self)"""
        return str(self)

    @property
    def size(self) -> int:
        """Get the number of lines."""
        return len(self.lines)

    @property
    def numbered_lines(self) -> List[str]:
        """
        Get list of lines with line numbers included.

        Returns:
            List[str]: Lines with numbers and separator prefixed

        Examples:
            >>> doc = NumberedText("First line\\nSecond line")
            >>> doc.numbered_lines
            ['1: First line', '2: Second line']

        Note:
            - Unlike str(self), this returns a list rather than joined string
            - Maintains consistent formatting with separator
            - Useful for processing or displaying individual numbered lines
        """
        return [
            f"{i}{self.separator}{line}"
            for i, line in enumerate(self.lines, self.start)
        ]

    @property
    def end(self) -> int:
        return self.start + len(self.lines) - 1
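The numbered rendering produced by `__str__` and `numbered_lines` pairs `enumerate` with the configured start and separator. A minimal standalone sketch of that formatting (the helper name here is illustrative, not part of the package):

```python
# Minimal standalone sketch of the formatting applied by __str__ and
# numbered_lines: each line is prefixed with its number and the
# configured separator.
def number_lines(text: str, start: int = 1, separator: str = ":") -> str:
    return "\n".join(
        f"{i}{separator}{line}"
        for i, line in enumerate(text.splitlines(), start)
    )
```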
content property

Get original text without line numbers.

end property
lines = [] instance-attribute
numbered_content property

Get text with line numbers as a string. Equivalent to str(self)

numbered_lines property

Get list of lines with line numbers included.

Returns:

List[str]: Lines with numbers and separator prefixed

Examples:

>>> doc = NumberedText("First line\nSecond line")
>>> doc.numbered_lines
['1: First line', '2: Second line']
Note
  • Unlike str(self), this returns a list rather than joined string
  • Maintains consistent formatting with separator
  • Useful for processing or displaying individual numbered lines
separator = separator instance-attribute
size property

Get the number of lines.

start = start instance-attribute
LineSegment dataclass

Represents a segment of lines with start and end indices in 1-based indexing.

The segment follows Python range conventions where start is inclusive and end is exclusive. However, indexing is 1-based to match NumberedText.

Attributes:

start (int): Starting line number (inclusive, 1-based)
end (int): Ending line number (exclusive, 1-based)

Source code in src/tnh_scholar/text_processing/numbered_text.py
@dataclass
class LineSegment:
    """
    Represents a segment of lines with start and end indices in 1-based indexing.

    The segment follows Python range conventions where start is inclusive and
    end is exclusive. However, indexing is 1-based to match NumberedText.

    Attributes:
        start: Starting line number (inclusive, 1-based)
        end: Ending line number (exclusive, 1-based)
    """

    start: int
    end: int

    def __iter__(self):
        """Allow unpacking into start, end pairs."""
        yield self.start
        yield self.end
end instance-attribute
start instance-attribute
__init__(start, end)
__iter__()

Allow unpacking into start, end pairs.

Source code in src/tnh_scholar/text_processing/numbered_text.py
58
59
60
61
def __iter__(self):
    """Allow unpacking into start, end pairs."""
    yield self.start
    yield self.end
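Yielding the fields from `__iter__` is what allows a `LineSegment` to be unpacked directly into a `start, end` pair. A standalone sketch of the pattern (the `Segment` name here is illustrative):

```python
from dataclasses import dataclass

# Standalone sketch of the unpacking pattern: yielding the fields from
# __iter__ lets callers write `start, end = segment`.
@dataclass
class Segment:
    start: int
    end: int

    def __iter__(self):
        yield self.start
        yield self.end

start, end = Segment(1, 3)  # tuple-style unpacking via __iter__
```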
SegmentIterator

Iterator for generating line segments of specified size.

Produces segments of lines with start/end indices following 1-based indexing. The final segment may be smaller than the specified segment size.

Attributes:

total_lines: Total number of lines in text
segment_size: Number of lines per segment
start_line: Starting line number (1-based)
min_segment_size: Minimum size for the final segment

Source code in src/tnh_scholar/text_processing/numbered_text.py
class SegmentIterator:
    """
    Iterator for generating line segments of specified size.

    Produces segments of lines with start/end indices following 1-based indexing.
    The final segment may be smaller than the specified segment size.

    Attributes:
        total_lines: Total number of lines in text
        segment_size: Number of lines per segment
        start_line: Starting line number (1-based)
        min_segment_size: Minimum size for the final segment
    """

    def __init__(
        self,
        total_lines: int,
        segment_size: int,
        start_line: int = 1,
        min_segment_size: Optional[int] = None,
    ):
        """
        Initialize the segment iterator.

        Args:
            total_lines: Total number of lines to iterate over
            segment_size: Desired size of each segment
            start_line: First line number (default: 1)
            min_segment_size: Minimum size for final segment (default: None)
                If specified, the last segment will be merged with the previous one
                if it would be smaller than this size.

        Raises:
            ValueError: If segment_size < 1 or total_lines < 1
            ValueError: If start_line < 1 (must use 1-based indexing)
            ValueError: If min_segment_size >= segment_size
        """
        if segment_size < 1:
            raise ValueError("Segment size must be at least 1")
        if total_lines < 1:
            raise ValueError("Total lines must be at least 1")
        if start_line < 1:
            raise ValueError("Start line must be at least 1 (1-based indexing)")
        if min_segment_size is not None and min_segment_size >= segment_size:
            raise ValueError("Minimum segment size must be less than segment size")

        self.total_lines = total_lines
        self.segment_size = segment_size
        self.start_line = start_line
        self.min_segment_size = min_segment_size

        # Calculate number of segments
        remaining_lines = total_lines - start_line + 1
        self.num_segments = (remaining_lines + segment_size - 1) // segment_size

    def __iter__(self) -> Iterator["NumberedText.LineSegment"]:
        """
        Iterate over line segments.

        Yields:
            LineSegment containing start (inclusive) and end (exclusive) indices
        """
        current = self.start_line

        while current <= self.total_lines:
            segment_end = min(current + self.segment_size, self.total_lines + 1)
            remaining = self.total_lines + 1 - segment_end

            # If the final segment would fall below min_segment_size,
            # merge it into this segment by extending the end.
            if (
                self.min_segment_size is not None
                and 0 < remaining < self.min_segment_size
            ):
                segment_end = self.total_lines + 1

            yield NumberedText.LineSegment(current, segment_end)
            current = segment_end
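The `num_segments` computation above is integer ceiling division. A standalone check of that arithmetic (the helper name here is illustrative):

```python
import math

# Standalone check of the ceiling-division arithmetic used to compute
# num_segments: (remaining + segment_size - 1) // segment_size.
def count_segments(total_lines: int, segment_size: int, start_line: int = 1) -> int:
    remaining = total_lines - start_line + 1
    return (remaining + segment_size - 1) // segment_size

# Integer ceiling division agrees with math.ceil on the same ratio.
assert count_segments(10, 4) == math.ceil(10 / 4)
```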
min_segment_size = min_segment_size instance-attribute
num_segments = (remaining_lines + segment_size - 1) // segment_size instance-attribute
segment_size = segment_size instance-attribute
start_line = start_line instance-attribute
total_lines = total_lines instance-attribute
__init__(total_lines, segment_size, start_line=1, min_segment_size=None)

Initialize the segment iterator.

Parameters:

total_lines (int): Total number of lines to iterate over. Required.
segment_size (int): Desired size of each segment. Required.
start_line (int): First line number. Default: 1.
min_segment_size (Optional[int]): Minimum size for final segment. If specified, the last segment will be merged with the previous one if it would be smaller than this size. Default: None.

Raises:

ValueError: If segment_size < 1 or total_lines < 1
ValueError: If start_line < 1 (must use 1-based indexing)
ValueError: If min_segment_size >= segment_size

Source code in src/tnh_scholar/text_processing/numbered_text.py
def __init__(
    self,
    total_lines: int,
    segment_size: int,
    start_line: int = 1,
    min_segment_size: Optional[int] = None,
):
    """
    Initialize the segment iterator.

    Args:
        total_lines: Total number of lines to iterate over
        segment_size: Desired size of each segment
        start_line: First line number (default: 1)
        min_segment_size: Minimum size for final segment (default: None)
            If specified, the last segment will be merged with the previous one
            if it would be smaller than this size.

    Raises:
        ValueError: If segment_size < 1 or total_lines < 1
        ValueError: If start_line < 1 (must use 1-based indexing)
        ValueError: If min_segment_size >= segment_size
    """
    if segment_size < 1:
        raise ValueError("Segment size must be at least 1")
    if total_lines < 1:
        raise ValueError("Total lines must be at least 1")
    if start_line < 1:
        raise ValueError("Start line must be at least 1 (1-based indexing)")
    if min_segment_size is not None and min_segment_size >= segment_size:
        raise ValueError("Minimum segment size must be less than segment size")

    self.total_lines = total_lines
    self.segment_size = segment_size
    self.start_line = start_line
    self.min_segment_size = min_segment_size

    # Calculate number of segments
    remaining_lines = total_lines - start_line + 1
    self.num_segments = (remaining_lines + segment_size - 1) // segment_size
__iter__()

Iterate over line segments.

Yields:

LineSegment: LineSegment containing start (inclusive) and end (exclusive) indices

Source code in src/tnh_scholar/text_processing/numbered_text.py
def __iter__(self) -> Iterator["NumberedText.LineSegment"]:
    """
    Iterate over line segments.

    Yields:
        LineSegment containing start (inclusive) and end (exclusive) indices
    """
    current = self.start_line

    while current <= self.total_lines:
        segment_end = min(current + self.segment_size, self.total_lines + 1)
        remaining = self.total_lines + 1 - segment_end

        # If the final segment would fall below min_segment_size,
        # merge it into this segment by extending the end.
        if (
            self.min_segment_size is not None
            and 0 < remaining < self.min_segment_size
        ):
            segment_end = self.total_lines + 1

        yield NumberedText.LineSegment(current, segment_end)
        current = segment_end
__getitem__(index)

Get line content by line number (1-based indexing).

Source code in src/tnh_scholar/text_processing/numbered_text.py
def __getitem__(self, index: int) -> str:
    """Get line content by line number (1-based indexing)."""
    return self.lines[self._to_internal_index(index)]
__init__(content=None, start=1, separator=':')

Initialize a numbered text document, detecting and preserving existing numbering.

Valid numbered text must have:
  • Sequential line numbers
  • Consistent separator character(s)
  • Every non-empty line must follow the numbering pattern

Parameters:

content (Optional[str]): Initial text content, if any. Default: None.
start (int): Starting line number (used only if content isn't already numbered). Default: 1.
separator (str): Separator between line numbers and content. Default: ':'.

Examples:

>>> # Custom separators
>>> doc = NumberedText("1→First line\n2→Second line")
>>> doc.separator == "→"
True
>>> # Preserves starting number
>>> doc = NumberedText("5#First\n6#Second")
>>> doc.start == 5
True
>>> # Regular numbered list isn't treated as line numbers
>>> doc = NumberedText("1. First item\n2. Second item")
>>> doc.numbered_lines
['1: 1. First item', '2: 2. Second item']
Source code in src/tnh_scholar/text_processing/numbered_text.py
def __init__(
    self, content: Optional[str] = None, start: int = 1, separator: str = ":"
) -> None:
    """
    Initialize a numbered text document, 
    detecting and preserving existing numbering.

    Valid numbered text must have:
    - Sequential line numbers
    - Consistent separator character(s)
    - Every non-empty line must follow the numbering pattern

    Args:
        content: Initial text content, if any
        start: Starting line number (used only if content isn't already numbered)
        separator: Separator between line numbers and content 
        (only if content isn't numbered)

    Examples:
        >>> # Custom separators
        >>> doc = NumberedText("1→First line\\n2→Second line")
        >>> doc.separator == "→"
        True

        >>> # Preserves starting number
        >>> doc = NumberedText("5#First\\n6#Second")
        >>> doc.start == 5
        True

        >>> # Regular numbered list isn't treated as line numbers
        >>> doc = NumberedText("1. First item\\n2. Second item")
        >>> doc.numbered_lines
        ['1: 1. First item', '2: 2. Second item']
    """

    self.lines: List[str] = []  # Declare lines here
    self.start: int = start  # Declare start with its type
    self.separator: str = separator  # and separator

    if not isinstance(content, str):
        raise ValueError("NumberedText requires string input.")

    if start < 1:  # enforce 1 based indexing.
        raise IndexError(
            "NumberedText: Numbered lines must begin on "
            "an integer greater than or equal to 1."
        )

    if not content:
        return

    # Analyze the text format
    format_info = get_numbered_format(content)

    if format_info.is_numbered:
        self.start = format_info.start_num  # type: ignore
        self.separator = format_info.separator  # type: ignore

        # Extract content by removing the number and separator prefix
        pattern = re.compile(rf"^\d+{re.escape(format_info.separator)}")  # type: ignore
        self.lines = []

        for line in content.splitlines():
            if line.strip():
                self.lines.append(pattern.sub("", line))
            else:
                self.lines.append(line)
    else:
        self.lines = content.splitlines()
        self.start = start
        self.separator = separator
__iter__()

Iterate over (line_number, line_content) pairs.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def __iter__(self) -> Iterator[tuple[int, str]]:
    """Iterate over (line_number, line_content) pairs."""
    return iter((i, line) for i, line in enumerate(self.lines, self.start))
__len__()

Return the number of lines.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def __len__(self) -> int:
    """Return the number of lines."""
    return len(self.lines)
__str__()

Return the numbered text representation.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def __str__(self) -> str:
    """Return the numbered text representation."""
    return "\n".join(
        self._format_line(i, line) for i, line in enumerate(self.lines, self.start)
    )
append(text)

Append text, splitting into lines if needed.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def append(self, text: str) -> None:
    """Append text, splitting into lines if needed."""
    self.lines.extend(text.splitlines())
from_file(path, **kwargs) classmethod

Create a NumberedText instance from a file.

Source code in src/tnh_scholar/text_processing/numbered_text.py
@classmethod
def from_file(cls, path: Path, **kwargs) -> "NumberedText":
    """Create a NumberedText instance from a file."""
    return cls(read_str_from_file(Path(path)), **kwargs)
get_line(line_num)

Get content of specified line number.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def get_line(self, line_num: int) -> str:
    """Get content of specified line number."""
    return self[line_num]
get_lines(start, end)

Get content of line range, not inclusive of end line.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def get_lines(self, start: int, end: int) -> List[str]:
    """Get content of line range, not inclusive of end line."""
    return self.lines[self._to_internal_index(start) : self._to_internal_index(end)]
get_numbered_line(line_num)

Get specified line with line number.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def get_numbered_line(self, line_num: int) -> str:
    """Get specified line with line number."""
    idx = self._to_line_index(line_num)
    return self._format_line(idx, self[idx])
get_numbered_lines(start, end)
Source code in src/tnh_scholar/text_processing/numbered_text.py
def get_numbered_lines(self, start: int, end: int) -> List[str]:
    return [
        self._format_line(start + i, line)
        for i, line in enumerate(self.get_lines(start, end))
    ]
get_numbered_segment(start, end)
Source code in src/tnh_scholar/text_processing/numbered_text.py
def get_numbered_segment(self, start: int, end: int) -> str:
    return "\n".join(self.get_numbered_lines(start, end))
get_segment(start, end)

return the segment from start line (inclusive) up to end line (exclusive)

Source code in src/tnh_scholar/text_processing/numbered_text.py
def get_segment(self, start: int, end: int) -> str:
    """return the segment from start line (inclusive) up to end line (exclusive)"""
    if start < self.start:
        raise IndexError(f"Start index {start} is before first line {self.start}")
    if end > len(self) + 1:
        raise IndexError(f"End index {end} is past last line {len(self)}")
    if start >= end:
        raise IndexError(f"Start index {start} must be less than end index {end}")
    return "\n".join(self.get_lines(start, end))
insert(line_num, text)

Insert text at specified line number. Assumes text is not empty.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def insert(self, line_num: int, text: str) -> None:
    """Insert text at specified line number. Assumes text is not empty."""
    new_lines = text.splitlines()
    internal_idx = self._to_internal_index(line_num)
    self.lines[internal_idx:internal_idx] = new_lines
iter_segments(segment_size, min_segment_size=None)

Iterate over segments of the text with specified size.

Parameters:

segment_size (int): Number of lines per segment. Required.
min_segment_size (Optional[int]): Optional minimum size for final segment. If specified, last segment will be merged with previous one if it would be smaller than this size. Default: None.

Yields:

LineSegment: LineSegment objects containing start and end line numbers

Example

>>> text = NumberedText("line1\nline2\nline3\nline4\nline5")
>>> for segment in text.iter_segments(2):
...     print(f"Lines {segment.start}-{segment.end}")
Lines 1-3
Lines 3-5
Lines 5-6

Source code in src/tnh_scholar/text_processing/numbered_text.py
def iter_segments(
    self, segment_size: int, min_segment_size: Optional[int] = None
) -> Iterator[LineSegment]:
    """
    Iterate over segments of the text with specified size.

    Args:
        segment_size: Number of lines per segment
        min_segment_size: Optional minimum size for final segment.
            If specified, last segment will be merged with previous one
            if it would be smaller than this size.

    Yields:
        LineSegment objects containing start and end line numbers

    Example:
        >>> text = NumberedText("line1\\nline2\\nline3\\nline4\\nline5")
        >>> for segment in text.iter_segments(2):
        ...     print(f"Lines {segment.start}-{segment.end}")
        Lines 1-3
        Lines 3-5
        Lines 5-6
    """
    iterator = self.SegmentIterator(
        len(self), segment_size, self.start, min_segment_size
    )
    return iter(iterator)
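The documented `min_segment_size` rule, folding a short final segment into the previous one rather than emitting it alone, can be sketched standalone (names here are illustrative, not part of the package API):

```python
# Illustrative sketch of the documented min_segment_size rule: a final
# segment shorter than the minimum is absorbed into the previous segment.
def iter_bounds(total_lines, segment_size, min_segment_size=None):
    current = 1
    while current <= total_lines:
        end = min(current + segment_size, total_lines + 1)
        remaining = total_lines + 1 - end
        if min_segment_size is not None and 0 < remaining < min_segment_size:
            end = total_lines + 1  # absorb the short tail into this segment
        yield current, end
        current = end
```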
remove_whitespace()

Remove leading and trailing whitespace from all lines.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def remove_whitespace(self) -> None:
    """Remove leading and trailing whitespace from all lines."""
    self.lines = [line.strip() for line in self.lines]
reset_numbering()
Source code in src/tnh_scholar/text_processing/numbered_text.py
def reset_numbering(self):
    self.start = 1
save(path, numbered=True)

Save document to file.

Parameters:

path (Path): Output file path. Required.
numbered (bool): Whether to save with line numbers. Default: True.
Source code in src/tnh_scholar/text_processing/numbered_text.py
def save(self, path: Path, numbered: bool = True) -> None:
    """
    Save document to file.

    Args:
        path: Output file path
        numbered: Whether to save with line numbers (default: True)
    """
    content = str(self) if numbered else "\n".join(self.lines)
    write_str_to_file(path, content)
get_numbered_format(text)

Analyze text to determine if it follows a consistent line numbering format.

Valid formats have:
  • Sequential numbers starting from some value
  • Consistent separator character(s)
  • Every line must follow the format

Parameters:

text (str): Text to analyze. Required.

Returns:

NumberedFormat: Tuple of (is_numbered, separator, start_number)

Examples:

>>> get_numbered_format("1→First\n2→Second")
(True, "→", 1)
>>> get_numbered_format("1. First")  # Numbered list format
(False, None, None)
>>> get_numbered_format("5#Line\n6#Other")
(True, "#", 5)
Source code in src/tnh_scholar/text_processing/numbered_text.py
def get_numbered_format(text: str) -> NumberedFormat:
    """
    Analyze text to determine if it follows a consistent line numbering format.

    Valid formats have:
    - Sequential numbers starting from some value
    - Consistent separator character(s)
    - Every line must follow the format

    Args:
        text: Text to analyze

    Returns:
        Tuple of (is_numbered, separator, start_number)

    Examples:
        >>> get_numbered_format("1→First\\n2→Second")
        (True, "→", 1)
        >>> get_numbered_format("1. First")  # Numbered list format
        (False, None, None)
        >>> get_numbered_format("5#Line\\n6#Other")
        (True, "#", 5)
    """
    if not text.strip():
        return NumberedFormat(False)

    lines = [line for line in text.splitlines() if line.strip()]
    if not lines:
        return NumberedFormat(False)

    # Try to detect pattern from first line
    SEPARATOR_PATTERN = r"[^\w\s.]"  # not (word char or whitespace or period)
    first_match = re.match(rf"^(\d+)({SEPARATOR_PATTERN})(.*?)$", lines[0])

    if not first_match:
        return NumberedFormat(False)
    try:
        return _check_line_structure(first_match, lines)
    except (ValueError, AttributeError):
        return NumberedFormat(False)
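The first-line detection above hinges on one regex: a leading integer, then a single separator that is not a word character, whitespace, or a period, then the content. Excluding the period is what keeps ordinary numbered lists such as "1. First" from being read as line numbering. A standalone sketch (the helper name is illustrative):

```python
import re

# Sketch of the separator-detection regex: digits, one non-word,
# non-space, non-period separator character, then the content.
SEPARATOR_PATTERN = r"[^\w\s.]"

def detect_first_line(line: str):
    match = re.match(rf"^(\d+)({SEPARATOR_PATTERN})(.*?)$", line)
    if not match:
        return None  # e.g. "1. First" fails because "." is excluded
    number, sep, content = match.groups()
    return int(number), sep, content
```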

text_processing

clean_text(text, newline=False)

Cleans a given text by replacing specific unwanted characters such as tabs and non-breaking spaces with regular spaces.

This function takes a string as input and applies replacements based on a predefined mapping of characters to replace.

Parameters:

text (str): The text to be cleaned. Required.

Returns:

str: The cleaned text with unwanted characters replaced by spaces.

Example

>>> text = "This is\n an example\ttext with\xa0extra spaces."
>>> clean_text(text, newline=True)
'This is an example text with extra spaces.'

Source code in src/tnh_scholar/text_processing/text_processing.py
def clean_text(text, newline=False):
    """
    Cleans a given text by replacing specific unwanted characters such as
    tabs and non-breaking spaces with regular spaces.

    This function takes a string as input and applies replacements
    based on a predefined mapping of characters to replace.

    Args:
        text (str): The text to be cleaned.

    Returns:
        str: The cleaned text with unwanted characters replaced by spaces.

    Example:
        >>> text = "This is\\n an example\\ttext with\\xa0extra spaces."
        >>> clean_text(text, newline=True)
        'This is an example text with extra spaces.'

    """
    # Define a mapping of characters to replace
    replace_map = {
        "\t": " ",  # Replace tabs with space
        "\xa0": " ",  # Replace non-breaking space with regular space
        # Add more replacements as needed
    }

    if newline:
        replace_map["\n"] = ""  # remove newlines

    # Loop through the replace map and replace each character
    for old_char, new_char in replace_map.items():
        text = text.replace(old_char, new_char)

    return text.strip()  # Ensure any leading/trailing spaces are removed
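The replacement map above can also be applied in a single pass with `str.translate`; a standalone sketch of the same character mapping:

```python
# str.maketrans builds a translation table from the same replacement map
# clean_text loops over; translate applies it in one pass.
replace_map = {"\t": " ", "\xa0": " "}
table = str.maketrans(replace_map)

text = "This is an\texample with\xa0extra spaces."
cleaned = text.translate(table).strip()
```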
normalize_newlines(text, spacing=2)
Normalize newline blocks in the input text by reducing consecutive newlines
to the specified number of newlines for consistent readability and formatting.

Parameters:
----------
text : str
    The input text containing inconsistent newline spacing.
spacing : int, optional
    The number of newlines to insert between lines. Defaults to 2.

Returns:
-------
str
    The text with consecutive newlines reduced to the specified number of newlines.

Example:
--------
>>> raw_text = "Heading\n\n\nParagraph text 1\nParagraph text 2\n\n\n"
>>> normalize_newlines(raw_text, spacing=2)
'Heading\n\nParagraph text 1\n\nParagraph text 2\n\n'

Source code in src/tnh_scholar/text_processing/text_processing.py
def normalize_newlines(text: str, spacing: int = 2) -> str:
    """
    Normalize newline blocks in the input text by reducing consecutive newlines
    to the specified number of newlines for consistent readability and formatting.

    Parameters:
    ----------
    text : str
        The input text containing inconsistent newline spacing.
    spacing : int, optional
        The number of newlines to insert between lines. Defaults to 2.

    Returns:
    -------
    str
        The text with consecutive newlines reduced to the specified number of newlines.

    Example:
    --------
    >>> raw_text = "Heading\n\n\nParagraph text 1\nParagraph text 2\n\n\n"
    >>> normalize_newlines(raw_text, spacing=2)
    'Heading\n\nParagraph text 1\n\nParagraph text 2\n\n'
    """
    # Replace one or more newlines with the desired number of newlines
    newlines = "\n" * spacing
    return re.sub(r"\n{1,}", newlines, text)
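The regex above rewrites every run of one or more newlines to exactly `spacing` newlines, which both collapses wide gaps and expands single newlines; a standalone check of that rule:

```python
import re

# Every run of one or more newlines becomes exactly `spacing` newlines.
def normalize(text: str, spacing: int = 2) -> str:
    return re.sub(r"\n{1,}", "\n" * spacing, text)
```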

tools

Internal helper utilities for dev workflows.

notebook_prep

Utilities for maintaining paired *_local.ipynb notebooks.

EXCLUDED_PARTS = {'.ipynb_checkpoints'} module-attribute
prep_notebooks(directory, dry_run=True)

Create *_local notebooks and strip outputs from originals.

Parameters

directory: Directory whose notebooks will be processed.
dry_run: When True, only report pending work without copying files or invoking nbconvert.

Source code in src/tnh_scholar/tools/notebook_prep.py
def prep_notebooks(directory: Path | str, dry_run: bool = True) -> bool:
    """Create *_local notebooks and strip outputs from originals.

    Parameters
    ----------
    directory:
        Directory whose notebooks will be processed.
    dry_run:
        When ``True`` only report pending work without copying files or invoking
        ``nbconvert``.
    """

    directory = Path(directory).expanduser()
    if not directory.exists():
        print(f"Directory not found: {directory}")
        return False

    notebooks = list(_iter_source_notebooks(directory))
    print(
        f"Found {len(notebooks)} notebooks to process in {directory}. "
        "Ignoring all checkpoint and *_local notebooks."
    )

    for nb_path in notebooks:
        local_path = nb_path.parent / f"{nb_path.stem}_local{nb_path.suffix}"

        if local_path.exists():
            print(f"No action required: local copy of notebook exists: {local_path}")
            continue
        if dry_run:
            print(f"Would copy: {nb_path} -> {local_path}")
        else:
            print(f"Copying: {nb_path} -> {local_path}")
            shutil.copy2(nb_path, local_path)

        if dry_run:
            print(f"Would strip outputs from: {nb_path}")
            continue

        print(f"Stripping outputs from: {nb_path}")
        subprocess.run(
            [
                "jupyter",
                "nbconvert",
                "--ClearOutputPreprocessor.enabled=True",
                "--inplace",
                str(nb_path),
            ],
            check=True,
        )

    return True
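The `*_local` naming above derives the sibling path purely from the source path. A quick sketch of that derivation (the notebook path is hypothetical):

```python
from pathlib import Path

nb_path = Path("notebooks/analysis.ipynb")  # hypothetical source notebook
local_path = nb_path.parent / f"{nb_path.stem}_local{nb_path.suffix}"
# local_path sits next to the original, with "_local" before the extension
```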

tree_builder

Helpers for generating directory-tree text files.

build_tree(root_dir, src_dir=None)

Generate directory trees for the project and optionally its source directory.

Source code in src/tnh_scholar/tools/tree_builder.py
def build_tree(root_dir: Path, src_dir: Optional[Path] = None) -> None:
    """Generate directory trees for the project and optionally its source directory."""
    if shutil.which("tree") is None:
        raise FileNotFoundError(
            "The 'tree' command is not found in the system PATH. Please install it first."
        )

    if not root_dir.exists() or not root_dir.is_dir():
        raise FileNotFoundError(
            f"The root directory '{root_dir}' does not exist or is not a directory."
        )

    project_tree_output = root_dir / "project_directory_tree.txt"
    subprocess.run(
        ["tree", "--gitignore", str(root_dir), "-o", str(project_tree_output)],
        check=True,
    )

    if src_dir:
        if not src_dir.exists() or not src_dir.is_dir():
            raise FileNotFoundError(
                f"The source directory '{src_dir}' does not exist or is not a directory."
            )
        src_tree_output = root_dir / "src_directory_tree.txt"
        subprocess.run(
            ["tree", "--gitignore", str(src_dir), "-o", str(src_tree_output)],
            check=True,
        )

utils

__all__ = ['copy_files_with_regex', 'ensure_directory_exists', 'ensure_directory_writable', 'iterate_subdir', 'path_as_str', 'read_str_from_file', 'sanitize_filename', 'to_slug', 'write_str_to_file', 'load_json_into_model', 'load_jsonl_to_dict', 'save_model_to_json', 'get_language_code_from_text', 'get_language_from_code', 'get_language_name_from_text', 'ExpectedTimeTQDM', 'TimeProgress', 'TimeMs', 'TNHAudioSegment', 'convert_ms_to_sec', 'convert_sec_to_ms', 'get_user_confirmation', 'check_ocr_env', 'check_openai_env'] module-attribute

ExpectedTimeTQDM

A context manager for a time-based tqdm progress bar with optional delay.

  • 'expected_time': number of seconds we anticipate the task might take.
  • 'display_interval': how often (seconds) to refresh the bar.
  • 'desc': a short description for the bar.
  • 'delay_start': how many seconds to wait (sleep) before we even create/start the bar.

If the task finishes before 'delay_start' has elapsed, the bar may never appear.

Source code in src/tnh_scholar/utils/progress_utils.py
class ExpectedTimeTQDM:
    """
    A context manager for a time-based tqdm progress bar with optional delay.

    - 'expected_time': number of seconds we anticipate the task might take.
    - 'display_interval': how often (seconds) to refresh the bar.
    - 'desc': a short description for the bar.
    - 'delay_start': how many seconds to wait (sleep) before we even create/start the bar.

    If the task finishes before 'delay_start' has elapsed, the bar may never appear.
    """

    def __init__(
        self,
        expected_time: float,
        display_interval: float = 0.5,
        desc: str = "Time-based Progress",
        delay_start: float = 1.0,
    ) -> None:
        self.expected_time = round(expected_time)  # use nearest second.
        self.display_interval = display_interval
        self.desc = desc
        self.delay_start = delay_start

        self._stop_event = threading.Event()
        self._pbar = None  # We won't create the bar until after 'delay_start'
        self._start_time = None

    def __enter__(self):
        # Record the start time for reference
        self._start_time = time.time()

        # Spawn the background thread; it will handle waiting and then creating/updating the bar
        self._thread = threading.Thread(target=self._update_bar, daemon=True)
        self._thread.start()

        return self

    def _update_bar(self):
        # 1) Delay so warnings/logs can appear before the bar
        if self.delay_start > 0:
            time.sleep(self.delay_start)

        # 2) Create the tqdm bar (only now does it appear)
        self._pbar = tqdm(
            total=self.expected_time, desc=self.desc, unit="sec", bar_format=BAR_FORMAT
        )

        # 3) Update until told to stop
        while not self._stop_event.is_set():
            elapsed = time.time() - self._start_time
            current_value = min(elapsed, self.expected_time)
            if self._pbar:
                self._pbar.n = round(current_value)
                self._pbar.refresh()
            time.sleep(self.display_interval)

    def __exit__(self, exc_type, exc_value, traceback):
        # Signal the thread to stop
        self._stop_event.set()
        self._thread.join()

        # If the bar was actually created (i.e., we didn't finish too quickly),
        # do a final update and close
        if self._pbar:
            elapsed = time.time() - self._start_time
            self._pbar.n = round(min(elapsed, self.expected_time))
            self._pbar.refresh()
            self._pbar.close()

delay_start = delay_start instance-attribute
desc = desc instance-attribute
display_interval = display_interval instance-attribute
expected_time = round(expected_time) instance-attribute
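The delayed-start pattern above (sleep for `delay_start`, then poll until an `Event` is set) can be sketched without tqdm; the names and timings below are illustrative:

```python
import threading
import time

stop = threading.Event()
ticks = []

def update_bar(delay_start: float = 0.05, interval: float = 0.02) -> None:
    time.sleep(delay_start)        # let early warnings/logs print first
    while not stop.is_set():
        ticks.append(time.time())  # stand-in for pbar.refresh()
        time.sleep(interval)

thread = threading.Thread(target=update_bar, daemon=True)
thread.start()
time.sleep(0.2)                    # simulated work
stop.set()
thread.join()
```

If the work finishes before `delay_start` elapses, the loop body never runs, which is why the real class guards against a bar that was never created.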
__enter__()
Source code in src/tnh_scholar/utils/progress_utils.py
def __enter__(self):
    # Record the start time for reference
    self._start_time = time.time()

    # Spawn the background thread; it will handle waiting and then creating/updating the bar
    self._thread = threading.Thread(target=self._update_bar, daemon=True)
    self._thread.start()

    return self
__exit__(exc_type, exc_value, traceback)
Source code in src/tnh_scholar/utils/progress_utils.py
def __exit__(self, exc_type, exc_value, traceback):
    # Signal the thread to stop
    self._stop_event.set()
    self._thread.join()

    # If the bar was actually created (i.e., we didn't finish too quickly),
    # do a final update and close
    if self._pbar:
        elapsed = time.time() - self._start_time
        self._pbar.n = round(min(elapsed, self.expected_time))
        self._pbar.refresh()
        self._pbar.close()
__init__(expected_time, display_interval=0.5, desc='Time-based Progress', delay_start=1.0)
Source code in src/tnh_scholar/utils/progress_utils.py
def __init__(
    self,
    expected_time: float,
    display_interval: float = 0.5,
    desc: str = "Time-based Progress",
    delay_start: float = 1.0,
) -> None:
    self.expected_time = round(expected_time)  # use nearest second.
    self.display_interval = display_interval
    self.desc = desc
    self.delay_start = delay_start

    self._stop_event = threading.Event()
    self._pbar = None  # We won't create the bar until after 'delay_start'
    self._start_time = None

TNHAudioSegment

Source code in src/tnh_scholar/utils/tnh_audio_segment.py
class TNHAudioSegment:
    def __init__(self, segment: _AudioSegment):
        self._segment = segment

    @staticmethod
    def from_file(file: str | Path | BytesIO, format: str | None = None, **kwargs) -> "TNHAudioSegment":
        """
        Wrapper: Load an audio file into a TNHAudioSegment.

        Args:
            file: Path to the audio file.
            format: Optional audio format (e.g., 'mp3', 'wav'). If None, pydub will attempt to infer it.
            **kwargs: Additional keyword arguments passed to pydub.AudioSegment.from_file.

        Returns:
            TNHAudioSegment instance containing the loaded audio.
        """
        return TNHAudioSegment(_AudioSegment.from_file(file, format=format, **kwargs))

    def export(self, out_f: str | BinaryIO, format: str, **kwargs) -> None:
        """
        Wrapper: Export the audio segment to a file-like object or file path.

        Args:
            out_f: File path or file-like object to write the audio data to.
            format: Audio format (e.g., 'mp3', 'wav').
            **kwargs: Additional keyword arguments passed to pydub.AudioSegment.export.
        """
        self._segment.export(out_f, format=format, **kwargs)

    @staticmethod
    def silent(duration: int) -> "TNHAudioSegment":
        return TNHAudioSegment(_AudioSegment.silent(duration=duration))

    @staticmethod
    def empty() -> "TNHAudioSegment":
        return TNHAudioSegment(_AudioSegment.empty())

    def __getitem__(self, key: int | slice) -> "TNHAudioSegment":
        return TNHAudioSegment(self._segment[key]) # type: ignore

    def __add__(self, other: "TNHAudioSegment") -> "TNHAudioSegment":
        return TNHAudioSegment(self._segment + other._segment)

    def __iadd__(self, other: "TNHAudioSegment") -> "TNHAudioSegment":
        self._segment = self._segment + other._segment
        return self

    def __len__(self) -> int:
        return len(self._segment)

    # Add more methods as needed, e.g., export, from_file, etc.

    @property
    def raw(self) -> _AudioSegment:
        """Access the underlying pydub.AudioSegment if needed."""
        return self._segment
raw property

Access the underlying pydub.AudioSegment if needed.
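TNHAudioSegment is a thin delegation wrapper around `pydub.AudioSegment`. The pattern generalizes to any inner object; a minimal sketch (class name is illustrative):

```python
class Wrapper:
    """Minimal sketch of the delegation pattern TNHAudioSegment uses."""

    def __init__(self, inner):
        self._inner = inner

    def __len__(self) -> int:
        return len(self._inner)  # forward dunder calls to the inner object

    @property
    def raw(self):
        return self._inner       # escape hatch to the wrapped object

w = Wrapper([1, 2, 3])
```

Keeping `raw` as an explicit escape hatch lets callers reach pydub functionality the wrapper does not re-expose.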

__add__(other)
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
def __add__(self, other: "TNHAudioSegment") -> "TNHAudioSegment":
    return TNHAudioSegment(self._segment + other._segment)
__getitem__(key)
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
def __getitem__(self, key: int | slice) -> "TNHAudioSegment":
    return TNHAudioSegment(self._segment[key]) # type: ignore
__iadd__(other)
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
def __iadd__(self, other: "TNHAudioSegment") -> "TNHAudioSegment":
    self._segment = self._segment + other._segment
    return self
__init__(segment)
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
def __init__(self, segment: _AudioSegment):
    self._segment = segment
__len__()
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
def __len__(self) -> int:
    return len(self._segment)
empty() staticmethod
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
@staticmethod
def empty() -> "TNHAudioSegment":
    return TNHAudioSegment(_AudioSegment.empty())
export(out_f, format, **kwargs)

Wrapper: Export the audio segment to a file-like object or file path.

Parameters:

Name Type Description Default
out_f str | BinaryIO

File path or file-like object to write the audio data to.

required
format str

Audio format (e.g., 'mp3', 'wav').

required
**kwargs

Additional keyword arguments passed to pydub.AudioSegment.export.

{}
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
def export(self, out_f: str | BinaryIO, format: str, **kwargs) -> None:
    """
    Wrapper: Export the audio segment to a file-like object or file path.

    Args:
        out_f: File path or file-like object to write the audio data to.
        format: Audio format (e.g., 'mp3', 'wav').
        **kwargs: Additional keyword arguments passed to pydub.AudioSegment.export.
    """
    self._segment.export(out_f, format=format, **kwargs)
from_file(file, format=None, **kwargs) staticmethod

Wrapper: Load an audio file into a TNHAudioSegment.

Parameters:

Name Type Description Default
file str | Path | BytesIO

Path to the audio file.

required
format str | None

Optional audio format (e.g., 'mp3', 'wav'). If None, pydub will attempt to infer it.

None
**kwargs

Additional keyword arguments passed to pydub.AudioSegment.from_file.

{}

Returns:

Type Description
TNHAudioSegment

TNHAudioSegment instance containing the loaded audio.

Source code in src/tnh_scholar/utils/tnh_audio_segment.py
@staticmethod
def from_file(file: str | Path | BytesIO, format: str | None = None, **kwargs) -> "TNHAudioSegment":
    """
    Wrapper: Load an audio file into a TNHAudioSegment.

    Args:
        file: Path to the audio file.
        format: Optional audio format (e.g., 'mp3', 'wav'). If None, pydub will attempt to infer it.
        **kwargs: Additional keyword arguments passed to pydub.AudioSegment.from_file.

    Returns:
        TNHAudioSegment instance containing the loaded audio.
    """
    return TNHAudioSegment(_AudioSegment.from_file(file, format=format, **kwargs))
silent(duration) staticmethod
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
@staticmethod
def silent(duration: int) -> "TNHAudioSegment":
    return TNHAudioSegment(_AudioSegment.silent(duration=duration))

TimeMs

Bases: int

Lightweight representation of a time interval or timestamp in milliseconds. Allows negative values.

Source code in src/tnh_scholar/utils/timing_utils.py
class TimeMs(int):
    """
    Lightweight representation of a time interval or timestamp in milliseconds.
    Allows negative values.
    """

    def __new__(cls, ms: Union[int, float, "TimeMs"]):
        if isinstance(ms, TimeMs):
            value = int(ms)
        elif isinstance(ms, (int, float)):
            if not math.isfinite(ms):
                raise ValueError("ms must be a finite number")
            value = round(ms)
        else:
            raise TypeError(f"ms must be a number or TimeMs, got {type(ms).__name__}")
        return int.__new__(cls, value)

    @classmethod
    def from_seconds(cls, seconds: int | float) -> "TimeMs":
        return cls(round(seconds * 1000))

    def to_ms(self) -> int:
        return int(self)

    def to_seconds(self) -> float:
        return float(self) / 1000

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type, handler: GetCoreSchemaHandler):
        return core_schema.with_info_plain_validator_function(
            cls._validate,
            serialization=core_schema.plain_serializer_function_ser_schema(lambda v: int(v)),
        )

    @classmethod
    def _validate(cls, value, info):
        """
        Pydantic core validator for TimeMs.

        Args:
            value: The value to validate.
            info: Pydantic core schema info (unused).

        Returns:
            TimeMs: Validated TimeMs instance.
        """
        return cls(value)

    def __add__(self, other):
        return TimeMs(int(self) + int(other))

    def __radd__(self, other):
        return TimeMs(int(other) + int(self))

    def __sub__(self, other):
        return TimeMs(int(self) - int(other))

    def __rsub__(self, other):
        return TimeMs(int(other) - int(self))

    def __repr__(self) -> str:
        return f"TimeMs({self.to_seconds():.3f}s)"
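Because TimeMs subclasses int, arithmetic operators must re-wrap their results, or they would silently decay to plain int. A minimal sketch of that pattern (the class name is hypothetical):

```python
class Ms(int):
    """Minimal int subclass mirroring the TimeMs re-wrapping pattern."""

    def __add__(self, other):
        # int.__add__ would return a plain int; re-wrap to keep the subclass.
        return Ms(int(self) + int(other))

    def to_seconds(self) -> float:
        return float(self) / 1000

total = Ms(500) + 250  # stays an Ms, so to_seconds() remains available
```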
__add__(other)
Source code in src/tnh_scholar/utils/timing_utils.py
def __add__(self, other):
    return TimeMs(int(self) + int(other))
__get_pydantic_core_schema__(source_type, handler) classmethod
Source code in src/tnh_scholar/utils/timing_utils.py
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler: GetCoreSchemaHandler):
    return core_schema.with_info_plain_validator_function(
        cls._validate,
        serialization=core_schema.plain_serializer_function_ser_schema(lambda v: int(v)),
    )
__new__(ms)
Source code in src/tnh_scholar/utils/timing_utils.py
def __new__(cls, ms: Union[int, float, "TimeMs"]):
    if isinstance(ms, TimeMs):
        value = int(ms)
    elif isinstance(ms, (int, float)):
        if not math.isfinite(ms):
            raise ValueError("ms must be a finite number")
        value = round(ms)
    else:
        raise TypeError(f"ms must be a number or TimeMs, got {type(ms).__name__}")
    return int.__new__(cls, value)
__radd__(other)
Source code in src/tnh_scholar/utils/timing_utils.py
def __radd__(self, other):
    return TimeMs(int(other) + int(self))
__repr__()
Source code in src/tnh_scholar/utils/timing_utils.py
def __repr__(self) -> str:
    return f"TimeMs({self.to_seconds():.3f}s)"
__rsub__(other)
Source code in src/tnh_scholar/utils/timing_utils.py
def __rsub__(self, other):
    return TimeMs(int(other) - int(self))
__sub__(other)
Source code in src/tnh_scholar/utils/timing_utils.py
def __sub__(self, other):
    return TimeMs(int(self) - int(other))
from_seconds(seconds) classmethod
Source code in src/tnh_scholar/utils/timing_utils.py
@classmethod
def from_seconds(cls, seconds: int | float) -> "TimeMs":
    return cls(round(seconds * 1000))
to_ms()
Source code in src/tnh_scholar/utils/timing_utils.py
def to_ms(self) -> int:
    return int(self)
to_seconds()
Source code in src/tnh_scholar/utils/timing_utils.py
def to_seconds(self) -> float:
    return float(self) / 1000

TimeProgress

A context manager for a time-based progress display using dots.

The display updates once per second, printing a dot and showing:

  • Expected time (if provided)
  • Elapsed time (always displayed)

Example:

>>> import time
>>> with TimeProgress(expected_time=60, desc="Transcribing..."):
...     time.sleep(5)  # Simulate work
[Expected Time: 1:00, Elapsed Time: 0:05] .....

Parameters:

Name Type Description Default
expected_time Optional[float]

Expected time in seconds. Optional.

None
display_interval float

How often to print a dot (seconds).

1.0
desc str

Description to display alongside the progress.

''
Source code in src/tnh_scholar/utils/progress_utils.py
class TimeProgress:
    """
    A context manager for a time-based progress display using dots.

    The display updates once per second, printing a dot and showing:
    - Expected time (if provided)
    - Elapsed time (always displayed)

    Example:
    >>> import time
    >>> with TimeProgress(expected_time=60, desc="Transcribing..."):
    ...     time.sleep(5)  # Simulate work
    [Expected Time: 1:00, Elapsed Time: 0:05] .....

    Args:
        expected_time (Optional[float]): Expected time in seconds. Optional.
        display_interval (float): How often to print a dot (seconds).
        desc (str): Description to display alongside the progress.
    """

    def __init__(
        self,
        expected_time: Optional[float] = None,
        display_interval: float = 1.0,
        desc: str = "",
    ):
        self.expected_time = expected_time
        self.display_interval = display_interval
        self._stop_event = threading.Event()
        self._start_time = None
        self._thread = None
        self.desc = desc
        self._last_length = 0  # To keep track of the last printed line length

    def __enter__(self):
        # Record the start time
        self._start_time = time.time()

        # Spawn the background thread
        self._thread = threading.Thread(target=self._print_progress, daemon=True)
        self._thread.start()

        return self

    def _print_progress(self):
        """
        Continuously prints a spinner that cycles through |, /, —, \\ along with elapsed/expected time.
        """
        symbols = ["|", "/", "—", "\\"]  # Symbols to alternate between
        symbol_index = 0  # Keep track of the current symbol

        while not self._stop_event.is_set():
            elapsed = time.time() - self._start_time

            # Format elapsed time as mm:ss
            elapsed_str = self._format_time(elapsed)

            # Format expected time if provided
            if self.expected_time is not None:
                expected_str = self._format_time(self.expected_time)
                header = f"{self.desc} [Expected Time: {expected_str}, Elapsed Time: {elapsed_str}]"
            else:
                header = f"{self.desc} [Elapsed Time: {elapsed_str}]"

            # Get the current symbol for the spinner
            spinner = symbols[symbol_index]

            # Construct the line with the spinner
            line = f"\r{header} {spinner}"

            # Write to stdout
            sys.stdout.write(line)
            sys.stdout.flush()

            # Update the symbol index to alternate
            symbol_index = (symbol_index + 1) % len(symbols)

            # Sleep before next update
            time.sleep(self.display_interval)

        # Clear the spinner after finishing
        sys.stdout.write("\r" + " " * len(line) + "\r")
        sys.stdout.flush()

    def __exit__(self, exc_type, exc_value, traceback):
        # Signal the thread to stop
        self._stop_event.set()
        self._thread.join()

        # Final elapsed time
        elapsed = time.time() - self._start_time
        elapsed_str = self._format_time(elapsed)

        # Construct the final line
        if self.expected_time is not None:
            expected_str = self._format_time(self.expected_time)
            final_header = f"{self.desc} [Expected Time: {expected_str}, Elapsed Time: {elapsed_str}]"
        else:
            final_header = f"{self.desc} [Elapsed Time: {elapsed_str}]"

        # Final dots
        final_line = f"\r{final_header}"

        # Clear the line and move to the next line
        padding = " " * max(self._last_length - len(final_line), 0)
        sys.stdout.write(final_line + padding + "\n")
        sys.stdout.flush()

    @staticmethod
    def _format_time(seconds: float) -> str:
        """
        Converts seconds to a formatted string (mm:ss).
        """
        minutes = int(seconds // 60)
        seconds = int(seconds % 60)
        return f"{minutes}:{seconds:02}"
desc = desc instance-attribute
display_interval = display_interval instance-attribute
expected_time = expected_time instance-attribute
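The mm:ss formatting used by both progress helpers can be checked in isolation; the standalone function below mirrors `TimeProgress._format_time`:

```python
def format_time(seconds: float) -> str:
    # Minutes, then zero-padded seconds, as in TimeProgress._format_time.
    minutes = int(seconds // 60)
    return f"{minutes}:{int(seconds % 60):02}"
```

Note that minutes are not zero-padded, so 65 seconds renders as "1:05" while 600 renders as "10:00".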
__enter__()
Source code in src/tnh_scholar/utils/progress_utils.py
def __enter__(self):
    # Record the start time
    self._start_time = time.time()

    # Spawn the background thread
    self._thread = threading.Thread(target=self._print_progress, daemon=True)
    self._thread.start()

    return self
__exit__(exc_type, exc_value, traceback)
Source code in src/tnh_scholar/utils/progress_utils.py
def __exit__(self, exc_type, exc_value, traceback):
    # Signal the thread to stop
    self._stop_event.set()
    self._thread.join()

    # Final elapsed time
    elapsed = time.time() - self._start_time
    elapsed_str = self._format_time(elapsed)

    # Construct the final line
    if self.expected_time is not None:
        expected_str = self._format_time(self.expected_time)
        final_header = f"{self.desc} [Expected Time: {expected_str}, Elapsed Time: {elapsed_str}]"
    else:
        final_header = f"{self.desc} [Elapsed Time: {elapsed_str}]"

    # Final dots
    final_line = f"\r{final_header}"

    # Clear the line and move to the next line
    padding = " " * max(self._last_length - len(final_line), 0)
    sys.stdout.write(final_line + padding + "\n")
    sys.stdout.flush()
__init__(expected_time=None, display_interval=1.0, desc='')
Source code in src/tnh_scholar/utils/progress_utils.py
def __init__(
    self,
    expected_time: Optional[float] = None,
    display_interval: float = 1.0,
    desc: str = "",
):
    self.expected_time = expected_time
    self.display_interval = display_interval
    self._stop_event = threading.Event()
    self._start_time = None
    self._thread = None
    self.desc = desc
    self._last_length = 0  # To keep track of the last printed line length

check_ocr_env(output=True)

Check OCR processing requirements.

Source code in src/tnh_scholar/utils/validate.py
def check_ocr_env(output: bool = True) -> bool:
    """Check OCR processing requirements."""
    return check_env(OCR_ENV_VARS, "OCR processing", output=output)

check_openai_env(output=True)

Check OpenAI API requirements.

Source code in src/tnh_scholar/utils/validate.py
def check_openai_env(output: bool = True) -> bool:
    """Check OpenAI API requirements."""
    return check_env(OPENAI_ENV_VARS, "OpenAI API access", output=output)

convert_ms_to_sec(ms)

Convert time from milliseconds (int) to seconds (float).

Source code in src/tnh_scholar/utils/timing_utils.py
def convert_ms_to_sec(ms: int) -> float:
    """Convert time from milliseconds (int) to seconds (float)."""
    return float(ms / 1000)

convert_sec_to_ms(val)

Convert seconds to milliseconds, rounding to the nearest integer.

Source code in src/tnh_scholar/utils/timing_utils.py
def convert_sec_to_ms(val: float) -> int:
    """ 
    Convert seconds to milliseconds, rounding to the nearest integer.
    """
    return round(val * 1000)
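Since `convert_sec_to_ms` rounds, the two conversions round-trip exactly only for values that are whole milliseconds; a quick self-contained check of both helpers:

```python
def convert_sec_to_ms(val: float) -> int:
    # Round to the nearest whole millisecond.
    return round(val * 1000)

def convert_ms_to_sec(ms: int) -> float:
    return float(ms / 1000)

ms = convert_sec_to_ms(1.5)       # exact: 1.5 s is a whole number of ms
sec = convert_ms_to_sec(ms)
```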

copy_files_with_regex(source_dir, destination_dir, regex_patterns, preserve_structure=True)

Copies files from subdirectories one level down in the source directory to the destination directory if they match any regex pattern. Optionally preserves the directory structure.

Parameters:

Name Type Description Default
source_dir Path

Path to the source directory to search files in.

required
destination_dir Path

Path to the destination directory where files will be copied.

required
regex_patterns list[str]

List of regex patterns to match file names.

required
preserve_structure bool

Whether to preserve the directory structure. Defaults to True.

True

Raises:

Type Description
ValueError

If the source directory does not exist or is not a directory.

Example

>>> copy_files_with_regex(
...     source_dir=Path("/path/to/source"),
...     destination_dir=Path("/path/to/destination"),
...     regex_patterns=[r'.*\.txt$', r'.*\.log$'],
...     preserve_structure=True
... )

Source code in src/tnh_scholar/utils/file_utils.py
def copy_files_with_regex(
    source_dir: Path,
    destination_dir: Path,
    regex_patterns: list[str],
    preserve_structure: bool = True,
) -> None:
    """
    Copies files from subdirectories one level down in the source directory to 
    the destination directory if they match any regex pattern. Optionally preserves the 
    directory structure.

    Args:
        source_dir (Path): Path to the source directory to search files in.
        destination_dir (Path): Path to the destination directory where files will be 
            copied.
        regex_patterns (list[str]): List of regex patterns to match file names.
        preserve_structure (bool): Whether to preserve the directory structure. 
            Defaults to True.

    Raises:
        ValueError: If the source directory does not exist or is not a directory.

    Example:
        >>> copy_files_with_regex(
        ...     source_dir=Path("/path/to/source"),
        ...     destination_dir=Path("/path/to/destination"),
        ...     regex_patterns=[r'.*\\.txt$', r'.*\\.log$'],
        ...     preserve_structure=True
        ... )
    """
    if not source_dir.is_dir():
        raise ValueError(
            f"The source directory {source_dir} does not exist or is not a directory."
        )

    if not destination_dir.exists():
        destination_dir.mkdir(parents=True, exist_ok=True)

    # Compile regex patterns for efficiency
    compiled_patterns = [re.compile(pattern) for pattern in regex_patterns]

    # Process only one level down
    for subdir in source_dir.iterdir():
        if subdir.is_dir():  # Only process subdirectories
            print(f"processing {subdir}:")
            for file_path in subdir.iterdir():  # Only files in this subdirectory
                if file_path.is_file():
                    print(f"checking file: {file_path.name}")
                    # Check if the file matches any of the regex patterns
                    if any(
                        pattern.match(file_path.name) for pattern in compiled_patterns
                    ):
                        if preserve_structure:
                            # Construct the target path, preserving relative structure
                            relative_path = (
                                subdir.relative_to(source_dir) / file_path.name
                            )
                            target_path = destination_dir / relative_path
                            target_path.parent.mkdir(parents=True, exist_ok=True)
                        else:
                            # Put directly in destination without subdirectory structure
                            target_path = destination_dir / file_path.name

                        shutil.copy2(file_path, target_path)
                        print(f"Copied: {file_path} -> {target_path}")
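The filename filter above compiles the patterns once and accepts a file when any pattern matches its name. The matching step in isolation, with illustrative filenames:

```python
import re

regex_patterns = [r".*\.txt$", r".*\.log$"]
compiled = [re.compile(p) for p in regex_patterns]  # compile once, reuse per file

names = ["notes.txt", "run.log", "image.png"]
matched = [n for n in names if any(p.match(n) for p in compiled)]
# matched keeps notes.txt and run.log; image.png is filtered out
```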

ensure_directory_exists(dir_path)

Create directory if it doesn't exist.

Parameters:

Name Type Description Default
dir_path Path

Directory path to ensure exists.

required

Returns:

Name Type Description
bool bool

True if the directory exists or was created successfully, False otherwise.

Source code in src/tnh_scholar/utils/file_utils.py
def ensure_directory_exists(dir_path: Path) -> bool:
    """
    Create directory if it doesn't exist.

    Args:
        dir_path (Path): Directory path to ensure exists.

    Returns:
        bool: True if the directory exists or was created successfully, False otherwise.
    """
    # No exception handling here. 
    # If exceptions occur let them propagate. 
    # Prototype code.

    dir_path.mkdir(parents=True, exist_ok=True)
    return True

ensure_directory_writable(dir_path)

Ensure the directory exists and is writable. Creates the directory if it does not exist.

Parameters:

Name Type Description Default
dir_path Path

Directory to verify or create.

required

Raises:

Type Description
ValueError

If the directory cannot be created or is not writable.

TypeError

If the provided path is not a Path instance.

Source code in src/tnh_scholar/utils/file_utils.py
def ensure_directory_writable(dir_path: Path) -> None:
    """
    Ensure the directory exists and is writable.
    Creates the directory if it does not exist.

    Args:
        dir_path (Path): Directory to verify or create.

    Raises:
        ValueError: If the directory cannot be created or is not writable.
        TypeError: If the provided path is not a Path instance.
    """
    if not isinstance(dir_path, Path):
        raise TypeError("dir_path must be a pathlib.Path instance")

    # Ensure directory exists first
    ensure_directory_exists(dir_path)

    # Check writability safely using NamedTemporaryFile
    try:
        with tempfile.NamedTemporaryFile(dir=dir_path, prefix=".writability_check_", delete=True) as tmp:
            tmp.write(b"test")
            tmp.flush()
    except Exception as e:
        raise ValueError(f"Directory is not writable: {dir_path}") from e
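The writability probe can be exercised standalone. This sketch reproduces the `NamedTemporaryFile` check against a throwaway directory (the `is_writable` helper is illustrative, not part of the package API):

```python
import tempfile
from pathlib import Path

# Dependency-free version of the probe used by ensure_directory_writable:
# create a short-lived temporary file inside the target directory and
# let any OSError signal that the directory is not writable.
def is_writable(dir_path: Path) -> bool:
    dir_path.mkdir(parents=True, exist_ok=True)
    try:
        with tempfile.NamedTemporaryFile(dir=dir_path, prefix=".writability_check_") as tmp:
            tmp.write(b"test")
            tmp.flush()
        return True
    except OSError:
        return False

workdir = Path(tempfile.mkdtemp()) / "out"
writable = is_writable(workdir)
```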

get_language_code_from_text(text)

Detect the language of the provided text using langdetect.

Parameters:

Name Type Description Default
text str

Text to analyze.

required

Returns:

Name Type Description
str str

ISO 639-1 code for the detected language, or "un" if detection fails.

Raises:

Type Description
ValueError

If text is empty or invalid

Source code in src/tnh_scholar/utils/lang.py
def get_language_code_from_text(text: str) -> str:
    """
    Detect the language of the provided text using langdetect.

    Args:
        text: Text to analyze.

    Returns:
        str: ISO 639-1 code for the detected language, or "un" if detection fails.

    Raises:
        ValueError: If text is empty or invalid
    """

    if not text or text.isspace():
        raise ValueError("Input text cannot be empty")

    sample = _get_sample_text(text)

    try:
        return detect(sample)
    except LangDetectException:
        logger.warning("Language could not be detected in get_language().")
        return "un"

get_language_from_code(code)

Source code in src/tnh_scholar/utils/lang.py
def get_language_from_code(code: str):
    if language := pycountry.languages.get(alpha_2=code):
        return language.name
    logger.warning(f"No language name found for code: {code}")
    return "Unknown"

get_language_name_from_text(text)

Source code in src/tnh_scholar/utils/lang.py
def get_language_name_from_text(text: str) -> str:
    return get_language_from_code(get_language_code_from_text(text))

get_user_confirmation(prompt, default=True)

Prompt the user for a yes/no confirmation with single-character input. Cross-platform implementation. Returns True if 'y' is entered and False if 'n'; pressing Enter returns the default value.

Example usage:

    if get_user_confirmation("Do you want to continue"):
        print("Continuing...")
    else:
        print("Exiting...")

Source code in src/tnh_scholar/utils/user_io_utils.py
def get_user_confirmation(prompt: str, default: bool = True) -> bool:
    """
    Prompt the user for a yes/no confirmation with single-character input.
    Cross-platform implementation. Returns True if 'y' is entered and False if 'n'.
    Pressing Enter returns the default value.

    Example usage
        if get_user_confirmation("Do you want to continue"):
            print("Continuing...")
        else:
            print("Exiting...")
    """
    print(f"{prompt} ", end="", flush=True)

    while True:
        char = get_single_char().lower()
        if char == "y":
            print(char)  # Echo the choice
            return True
        elif char == "n":
            print(char)
            return False
        elif char in ("\r", "\n"):  # Enter key (use default)
            print()  # Add a newline
            return default
        else:
            print(
                f"\nInvalid input: {char}. Please type 'y' or 'n': ", end="", flush=True
            )

iterate_subdir(directory, recursive=False)

Iterates through subdirectories in the given directory.

Parameters:

Name Type Description Default
directory Path

The root directory to start the iteration.

required
recursive bool

If True, iterates recursively through all subdirectories. If False, iterates only over the immediate subdirectories.

False

Yields:

Name Type Description
Path Path

Paths to each subdirectory.

Example

>>> for subdir in iterate_subdir(Path('/root'), recursive=False):
...     print(subdir)

Source code in src/tnh_scholar/utils/file_utils.py
def iterate_subdir(
    directory: Path, recursive: bool = False
) -> Generator[Path, None, None]:
    """
    Iterates through subdirectories in the given directory.

    Args:
        directory (Path): The root directory to start the iteration.
        recursive (bool): If True, iterates recursively through all subdirectories.
                          If False, iterates only over the immediate subdirectories.

    Yields:
        Path: Paths to each subdirectory.

    Example:
        >>> for subdir in iterate_subdir(Path('/root'), recursive=False):
        ...     print(subdir)
    """
    if recursive:
        for subdirectory in directory.rglob("*"):
            if subdirectory.is_dir():
                yield subdirectory
    else:
        for subdirectory in directory.iterdir():
            if subdirectory.is_dir():
                yield subdirectory
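The two traversal modes can be checked against a throwaway tree; the generator body below is copied from the listing above, made self-contained for the demonstration:

```python
import tempfile
from pathlib import Path

# iterate_subdir, as listed above: immediate children by default,
# full recursive walk via rglob when recursive=True.
def iterate_subdir(directory: Path, recursive: bool = False):
    candidates = directory.rglob("*") if recursive else directory.iterdir()
    for subdirectory in candidates:
        if subdirectory.is_dir():
            yield subdirectory

root = Path(tempfile.mkdtemp())
(root / "a" / "b").mkdir(parents=True)
shallow = sorted(p.name for p in iterate_subdir(root))
deep = sorted(p.name for p in iterate_subdir(root, recursive=True))
```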

load_json_into_model(file, model)

Loads a JSON file and validates it against a Pydantic model.

Parameters:

Name Type Description Default
file Path

Path to the JSON file.

required
model type[BaseModel]

The Pydantic model to validate against.

required

Returns:

Name Type Description
BaseModel BaseModel

An instance of the validated Pydantic model.

Raises:

Type Description
ValueError

If the file content is invalid JSON or does not match the model.

Example:

    class ExampleModel(BaseModel):
        name: str
        age: int
        city: str

if __name__ == "__main__":
    json_file = Path("example.json")
    try:
        data = load_json_into_model(json_file, ExampleModel)
        print(data)
    except ValueError as e:
        print(e)
Source code in src/tnh_scholar/utils/json_utils.py
def load_json_into_model(file: Path, model: type[BaseModel]) -> BaseModel:
    """
    Loads a JSON file and validates it against a Pydantic model.

    Args:
        file (Path): Path to the JSON file.
        model (type[BaseModel]): The Pydantic model to validate against.

    Returns:
        BaseModel: An instance of the validated Pydantic model.

    Raises:
        ValueError: If the file content is invalid JSON or does not match the model.
    Example:
        class ExampleModel(BaseModel):
            name: str
            age: int
            city: str

        if __name__ == "__main__":
            json_file = Path("example.json")
            try:
                data = load_json_into_model(json_file, ExampleModel)
                print(data)
            except ValueError as e:
                print(e)
    """
    try:
        with file.open("r", encoding="utf-8") as f:
            data = json.load(f)
        return model(**data)
    except (json.JSONDecodeError, ValidationError) as e:
        raise ValueError(f"Error loading or validating JSON file '{file}': {e}") from e
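The load-and-validate pattern can be sketched without Pydantic installed; here a stdlib dataclass stands in for the model (its `__init__` rejects unexpected keys, much as `model(**data)` would raise `ValidationError`). The `load_json_into_dataclass` helper is illustrative, not package API:

```python
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path

# Stand-in for a Pydantic model, for a dependency-free illustration.
@dataclass
class ExampleModel:
    name: str
    age: int

def load_json_into_dataclass(file: Path, model):
    # Same shape as load_json_into_model: parse, construct, and
    # re-raise parse/validation failures as ValueError.
    try:
        data = json.loads(file.read_text(encoding="utf-8"))
        return model(**data)
    except (json.JSONDecodeError, TypeError) as e:
        raise ValueError(f"Error loading or validating JSON file '{file}': {e}") from e

json_file = Path(tempfile.mkdtemp()) / "example.json"
json_file.write_text('{"name": "John", "age": 30}', encoding="utf-8")
loaded = load_json_into_dataclass(json_file, ExampleModel)
```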

load_jsonl_to_dict(file_path)

Load a JSONL file into a list of dictionaries.

Parameters:

Name Type Description Default
file_path Path

Path to the JSONL file.

required

Returns:

Type Description
List[Dict]

List[Dict]: A list of dictionaries, each representing a line in the JSONL file.

Example

>>> from pathlib import Path
>>> file_path = Path("data.jsonl")
>>> data = load_jsonl_to_dict(file_path)
>>> print(data)
[{'key1': 'value1'}, {'key2': 'value2'}]

Source code in src/tnh_scholar/utils/json_utils.py
def load_jsonl_to_dict(file_path: Path) -> List[Dict]:
    """
    Load a JSONL file into a list of dictionaries.

    Args:
        file_path (Path): Path to the JSONL file.

    Returns:
        List[Dict]: A list of dictionaries, each representing a line in the JSONL file.

    Example:
        >>> from pathlib import Path
        >>> file_path = Path("data.jsonl")
        >>> data = load_jsonl_to_dict(file_path)
        >>> print(data)
        [{'key1': 'value1'}, {'key2': 'value2'}]
    """
    with file_path.open("r", encoding="utf-8") as file:
        return [json.loads(line.strip()) for line in file if line.strip()]
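A quick round-trip check of the reader: blank lines are skipped and each remaining line parses to one dictionary. The function body is copied from the listing above:

```python
import json
import tempfile
from pathlib import Path

# load_jsonl_to_dict, as listed above: one JSON object per
# non-blank line, returned as a list of dicts.
def load_jsonl_to_dict(file_path: Path):
    with file_path.open("r", encoding="utf-8") as f:
        return [json.loads(line.strip()) for line in f if line.strip()]

path = Path(tempfile.mkdtemp()) / "data.jsonl"
path.write_text('{"key1": "value1"}\n\n{"key2": "value2"}\n', encoding="utf-8")
rows = load_jsonl_to_dict(path)
```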

path_as_str(path)

Source code in src/tnh_scholar/utils/file_utils.py
def path_as_str(path: Path) -> str:
    return str(path.resolve())

read_str_from_file(file_path)

Reads the entire content of a text file.

Parameters:

Name Type Description Default
file_path Path

The path to the text file.

required

Returns:

Type Description
str

The content of the text file as a single string.

Source code in src/tnh_scholar/utils/file_utils.py
def read_str_from_file(file_path: Path) -> str:
    """Reads the entire content of a text file.

    Args:
        file_path: The path to the text file.

    Returns:
        The content of the text file as a single string.
    """

    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()

sanitize_filename(filename, max_length=DEFAULT_MAX_FILENAME_LENGTH)

Sanitize a filename for Unix use.

Source code in src/tnh_scholar/utils/file_utils.py
def sanitize_filename(
    filename: str, 
    max_length: int = DEFAULT_MAX_FILENAME_LENGTH
    ) -> str:  
    """Sanitize a filename for Unix use."""

    # Normalize Unicode to remove accents and convert to ASCII
    clean = (
        unicodedata.normalize(
            "NFKD", 
            filename).encode(
                "ascii", 
                "ignore").decode("ascii")
    )

    clean = clean.lower()
    clean = re.sub(r"[^a-z0-9\s]", " ", clean.strip())
    clean = clean.strip()

    # shorten
    clean = clean[:max_length].strip() 

    # convert spaces to _
    clean = re.sub(r"\s+", "_", clean)

    return clean
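The stages of the pipeline (ASCII-fold, lowercase, strip punctuation, truncate, join words with underscores) can be seen on one input; the logic below is a compressed but behavior-equivalent copy of the listing above:

```python
import re
import unicodedata

def sanitize_filename(filename: str, max_length: int = 25) -> str:
    # ASCII-fold accents, lowercase, turn punctuation into spaces,
    # truncate, then collapse whitespace runs into underscores.
    clean = unicodedata.normalize("NFKD", filename).encode("ascii", "ignore").decode("ascii")
    clean = re.sub(r"[^a-z0-9\s]", " ", clean.lower().strip()).strip()
    clean = clean[:max_length].strip()
    return re.sub(r"\s+", "_", clean)

result = sanitize_filename("My File (v2).txt")
```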

save_model_to_json(file, model, indent=4, ensure_ascii=False)

Saves a Pydantic model to a JSON file, formatted with indentation for readability.

Parameters:

Name Type Description Default
file Path

Path to the JSON file where the model will be saved.

required
model BaseModel

The Pydantic model instance to save.

required
indent int

Number of spaces for JSON indentation. Defaults to 4.

4
ensure_ascii bool

Whether to escape non-ASCII characters. Defaults to False.

False

Raises:

Type Description
ValueError

If the model cannot be serialized to JSON.

IOError

If there is an issue writing to the file.

Example:

    class ExampleModel(BaseModel):
        name: str
        age: int

    if __name__ == "__main__":
        model_instance = ExampleModel(name="John", age=30)
        json_file = Path("example.json")
        try:
            save_model_to_json(json_file, model_instance)
            print(f"Model saved to {json_file}")
        except (ValueError, IOError) as e:
            print(e)

Source code in src/tnh_scholar/utils/json_utils.py
def save_model_to_json(
    file: Path, model: BaseModel, indent: int = 4, ensure_ascii: bool = False
) -> None:
    """
    Saves a Pydantic model to a JSON file, formatted with indentation for readability.

    Args:
        file (Path): Path to the JSON file where the model will be saved.
        model (BaseModel): The Pydantic model instance to save.
        indent (int): Number of spaces for JSON indentation. Defaults to 4.
        ensure_ascii (bool): Whether to escape non-ASCII characters. Defaults to False.

    Raises:
        ValueError: If the model cannot be serialized to JSON.
        IOError: If there is an issue writing to the file.

    Example:
        class ExampleModel(BaseModel):
            name: str
            age: int

        if __name__ == "__main__":
            model_instance = ExampleModel(name="John", age=30)
            json_file = Path("example.json")
            try:
                save_model_to_json(json_file, model_instance)
                print(f"Model saved to {json_file}")
            except (ValueError, IOError) as e:
                print(e)
    """
    try:
        # Serialize model to JSON string
        model_dict = model.model_dump()
    except TypeError as e:
        raise ValueError(f"Error serializing model to JSON: {e}") from e

    # Write the JSON string to the file
    write_data_to_json_file(file, model_dict, indent=indent, ensure_ascii=ensure_ascii)
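The save path can be sketched without Pydantic: `dataclasses.asdict` stands in for `model.model_dump()`, and the dict is written with the same `indent`/`ensure_ascii` defaults. The `save_dataclass_to_json` helper is illustrative, not package API:

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path

@dataclass
class ExampleModel:
    name: str
    age: int

def save_dataclass_to_json(file: Path, model, indent: int = 4, ensure_ascii: bool = False) -> None:
    # asdict plays the role of model.model_dump(); the JSON is written
    # indented for readability, matching save_model_to_json's defaults.
    file.write_text(json.dumps(asdict(model), indent=indent, ensure_ascii=ensure_ascii), encoding="utf-8")

json_file = Path(tempfile.mkdtemp()) / "example.json"
save_dataclass_to_json(json_file, ExampleModel(name="John", age=30))
saved = json.loads(json_file.read_text(encoding="utf-8"))
```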

to_slug(string)

Slugify a Unicode string.

Converts a string to a strict URL-friendly slug format, allowing only lowercase letters, digits, and hyphens.

Example

>>> to_slug("Héllø_Wörld!")
'hello-world'

Source code in src/tnh_scholar/utils/file_utils.py
def to_slug(string: str) -> str:
    """
    Slugify a Unicode string.

    Converts a string to a strict URL-friendly slug format,
    allowing only lowercase letters, digits, and hyphens.

    Example:
        >>> to_slug("Héllø_Wörld!")
        'hello-world'
    """
    # Normalize Unicode to remove accents and convert to ASCII
    string = (
        unicodedata.normalize("NFKD", string).encode("ascii", "ignore").decode("ascii")
    )

    # Replace all non-alphanumeric characters with spaces (only keep a-z and 0-9)
    string = re.sub(r"[^a-z0-9\s]", " ", string.lower().strip())

    # Replace any sequence of spaces with a single hyphen
    return re.sub(r"\s+", "-", string)
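The slug pipeline can be run standalone on a plain-ASCII input to see the hyphen-joining behavior; the function body is copied from the listing above:

```python
import re
import unicodedata

# to_slug, as listed above: ASCII-fold, lowercase, replace
# non-alphanumerics with spaces, join word runs with hyphens.
def to_slug(string: str) -> str:
    string = unicodedata.normalize("NFKD", string).encode("ascii", "ignore").decode("ascii")
    string = re.sub(r"[^a-z0-9\s]", " ", string.lower().strip())
    return re.sub(r"\s+", "-", string)

slug = to_slug("Plum Village Dharma Talks")
```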

write_str_to_file(file_path, text, overwrite=False)

Writes text to a file with file locking.

Parameters:

Name Type Description Default
file_path PathLike

The path to the file to write.

required
text str

The text to write to the file.

required
overwrite bool

Whether to overwrite the file if it exists.

False

Raises:

Type Description
FileExistsError

If the file exists and overwrite is False.

OSError

If there's an issue with file locking or writing.

Source code in src/tnh_scholar/utils/file_utils.py
def write_str_to_file(file_path: PathLike, text: str, overwrite: bool = False):
    """Writes text to a file with file locking.

    Args:
        file_path: The path to the file to write.
        text: The text to write to the file.
        overwrite: Whether to overwrite the file if it exists.

    Raises:
        FileExistsError: If the file exists and overwrite is False.
        OSError: If there's an issue with file locking or writing.
    """
    file_path = Path(file_path)

    if file_path.exists() and not overwrite:
        raise FileExistsError(f"File already exists: {file_path}")

    try:
        with file_path.open("w", encoding="utf-8") as f:
            fcntl.flock(f, fcntl.LOCK_EX)
            f.write(text)
            fcntl.flock(f, fcntl.LOCK_UN)  # Release lock
    except OSError as e:
        raise OSError(f"Error writing to or locking file {file_path}: {e}") from e
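A self-contained run of the locked-write helper (POSIX-only, via `fcntl`), exercised twice to show the `FileExistsError` guard when `overwrite` is left at its default:

```python
import fcntl
import tempfile
from pathlib import Path

# write_str_to_file, as listed above: refuse to clobber an existing
# file unless overwrite=True, and hold an exclusive lock while writing.
def write_str_to_file(file_path, text: str, overwrite: bool = False):
    file_path = Path(file_path)
    if file_path.exists() and not overwrite:
        raise FileExistsError(f"File already exists: {file_path}")
    with file_path.open("w", encoding="utf-8") as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        f.write(text)
        fcntl.flock(f, fcntl.LOCK_UN)

target = Path(tempfile.mkdtemp()) / "note.txt"
write_str_to_file(target, "hello")
try:
    write_str_to_file(target, "again")  # second write without overwrite
    guarded = False
except FileExistsError:
    guarded = True
```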

file_utils

DEFAULT_MAX_FILENAME_LENGTH = 25 module-attribute
PathLike = Union[str, Path] module-attribute
__all__ = ['DEFAULT_MAX_FILENAME_LENGTH', 'FileExistsWarning', 'ensure_directory_exists', 'ensure_directory_writable', 'iterate_subdir', 'path_source_str', 'copy_files_with_regex', 'read_str_from_file', 'write_str_to_file', 'sanitize_filename', 'to_slug', 'path_as_str'] module-attribute
FileExistsWarning

Bases: UserWarning

Source code in src/tnh_scholar/utils/file_utils.py
class FileExistsWarning(UserWarning):
    pass
copy_files_with_regex(source_dir, destination_dir, regex_patterns, preserve_structure=True)

Copies files from subdirectories one level down in the source directory to the destination directory if they match any regex pattern. Optionally preserves the directory structure.

Parameters:

Name Type Description Default
source_dir Path

Path to the source directory to search files in.

required
destination_dir Path

Path to the destination directory where files will be copied.

required
regex_patterns list[str]

List of regex patterns to match file names.

required
preserve_structure bool

Whether to preserve the directory structure. Defaults to True.

True

Raises:

Type Description
ValueError

If the source directory does not exist or is not a directory.

Example

>>> copy_files_with_regex(
...     source_dir=Path("/path/to/source"),
...     destination_dir=Path("/path/to/destination"),
...     regex_patterns=[r'.*\.txt$', r'.*\.log$'],
...     preserve_structure=True
... )

Source code in src/tnh_scholar/utils/file_utils.py
def copy_files_with_regex(
    source_dir: Path,
    destination_dir: Path,
    regex_patterns: list[str],
    preserve_structure: bool = True,
) -> None:
    """
    Copies files from subdirectories one level down in the source directory to 
    the destination directory if they match any regex pattern. Optionally preserves the 
    directory structure.

    Args:
        source_dir (Path): Path to the source directory to search files in.
        destination_dir (Path): Path to the destination directory where files will be 
            copied.
        regex_patterns (list[str]): List of regex patterns to match file names.
        preserve_structure (bool): Whether to preserve the directory structure. 
            Defaults to True.

    Raises:
        ValueError: If the source directory does not exist or is not a directory.

    Example:
        >>> copy_files_with_regex(
        ...     source_dir=Path("/path/to/source"),
        ...     destination_dir=Path("/path/to/destination"),
        ...     regex_patterns=[r'.*\\.txt$', r'.*\\.log$'],
        ...     preserve_structure=True
        ... )
    """
    if not source_dir.is_dir():
        raise ValueError(
            f"The source directory {source_dir} does not exist or is not a directory."
        )

    if not destination_dir.exists():
        destination_dir.mkdir(parents=True, exist_ok=True)

    # Compile regex patterns for efficiency
    compiled_patterns = [re.compile(pattern) for pattern in regex_patterns]

    # Process only one level down
    for subdir in source_dir.iterdir():
        if subdir.is_dir():  # Only process subdirectories
            print(f"processing {subdir}:")
            for file_path in subdir.iterdir():  # Only files in this subdirectory
                if file_path.is_file():
                    print(f"checking file: {file_path.name}")
                    # Check if the file matches any of the regex patterns
                    if any(
                        pattern.match(file_path.name) for pattern in compiled_patterns
                    ):
                        if preserve_structure:
                            # Construct the target path, preserving relative structure
                            relative_path = (
                                subdir.relative_to(source_dir) / file_path.name
                            )
                            target_path = destination_dir / relative_path
                            target_path.parent.mkdir(parents=True, exist_ok=True)
                        else:
                            # Put directly in destination without subdirectory structure
                            target_path = destination_dir / file_path.name

                        shutil.copy2(file_path, target_path)
                        print(f"Copied: {file_path} -> {target_path}")
ensure_directory_exists(dir_path)

Create directory if it doesn't exist.

Parameters:

Name Type Description Default
dir_path Path

Directory path to ensure exists.

required

Returns:

Name Type Description
bool bool

True if the directory exists or was created successfully, False otherwise.

Source code in src/tnh_scholar/utils/file_utils.py
def ensure_directory_exists(dir_path: Path) -> bool:
    """
    Create directory if it doesn't exist.

    Args:
        dir_path (Path): Directory path to ensure exists.

    Returns:
        bool: True if the directory exists or was created successfully, False otherwise.
    """
    # No exception handling here. 
    # If exceptions occur let them propagate. 
    # Prototype code.

    dir_path.mkdir(parents=True, exist_ok=True)
    return True
ensure_directory_writable(dir_path)

Ensure the directory exists and is writable. Creates the directory if it does not exist.

Parameters:

Name Type Description Default
dir_path Path

Directory to verify or create.

required

Raises:

Type Description
ValueError

If the directory cannot be created or is not writable.

TypeError

If the provided path is not a Path instance.

Source code in src/tnh_scholar/utils/file_utils.py
def ensure_directory_writable(dir_path: Path) -> None:
    """
    Ensure the directory exists and is writable.
    Creates the directory if it does not exist.

    Args:
        dir_path (Path): Directory to verify or create.

    Raises:
        ValueError: If the directory cannot be created or is not writable.
        TypeError: If the provided path is not a Path instance.
    """
    if not isinstance(dir_path, Path):
        raise TypeError("dir_path must be a pathlib.Path instance")

    # Ensure directory exists first
    ensure_directory_exists(dir_path)

    # Check writability safely using NamedTemporaryFile
    try:
        with tempfile.NamedTemporaryFile(dir=dir_path, prefix=".writability_check_", delete=True) as tmp:
            tmp.write(b"test")
            tmp.flush()
    except Exception as e:
        raise ValueError(f"Directory is not writable: {dir_path}") from e
iterate_subdir(directory, recursive=False)

Iterates through subdirectories in the given directory.

Parameters:

Name Type Description Default
directory Path

The root directory to start the iteration.

required
recursive bool

If True, iterates recursively through all subdirectories. If False, iterates only over the immediate subdirectories.

False

Yields:

Name Type Description
Path Path

Paths to each subdirectory.

Example

>>> for subdir in iterate_subdir(Path('/root'), recursive=False):
...     print(subdir)

Source code in src/tnh_scholar/utils/file_utils.py
def iterate_subdir(
    directory: Path, recursive: bool = False
) -> Generator[Path, None, None]:
    """
    Iterates through subdirectories in the given directory.

    Args:
        directory (Path): The root directory to start the iteration.
        recursive (bool): If True, iterates recursively through all subdirectories.
                          If False, iterates only over the immediate subdirectories.

    Yields:
        Path: Paths to each subdirectory.

    Example:
        >>> for subdir in iterate_subdir(Path('/root'), recursive=False):
        ...     print(subdir)
    """
    if recursive:
        for subdirectory in directory.rglob("*"):
            if subdirectory.is_dir():
                yield subdirectory
    else:
        for subdirectory in directory.iterdir():
            if subdirectory.is_dir():
                yield subdirectory
path_as_str(path)
Source code in src/tnh_scholar/utils/file_utils.py
def path_as_str(path: Path) -> str:
    return str(path.resolve())
path_source_str(path)
Source code in src/tnh_scholar/utils/file_utils.py
def path_source_str(path: Path):
    return str(path.resolve())
read_str_from_file(file_path)

Reads the entire content of a text file.

Parameters:

Name Type Description Default
file_path Path

The path to the text file.

required

Returns:

Type Description
str

The content of the text file as a single string.

Source code in src/tnh_scholar/utils/file_utils.py
def read_str_from_file(file_path: Path) -> str:
    """Reads the entire content of a text file.

    Args:
        file_path: The path to the text file.

    Returns:
        The content of the text file as a single string.
    """

    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()
sanitize_filename(filename, max_length=DEFAULT_MAX_FILENAME_LENGTH)

Sanitize a filename for Unix use.

Source code in src/tnh_scholar/utils/file_utils.py
def sanitize_filename(
    filename: str, 
    max_length: int = DEFAULT_MAX_FILENAME_LENGTH
    ) -> str:  
    """Sanitize a filename for Unix use."""

    # Normalize Unicode to remove accents and convert to ASCII
    clean = (
        unicodedata.normalize(
            "NFKD", 
            filename).encode(
                "ascii", 
                "ignore").decode("ascii")
    )

    clean = clean.lower()
    clean = re.sub(r"[^a-z0-9\s]", " ", clean.strip())
    clean = clean.strip()

    # shorten
    clean = clean[:max_length].strip() 

    # convert spaces to _
    clean = re.sub(r"\s+", "_", clean)

    return clean
to_slug(string)

Slugify a Unicode string.

Converts a string to a strict URL-friendly slug format, allowing only lowercase letters, digits, and hyphens.

Example

>>> to_slug("Héllø_Wörld!")
'hello-world'

Source code in src/tnh_scholar/utils/file_utils.py
def to_slug(string: str) -> str:
    """
    Slugify a Unicode string.

    Converts a string to a strict URL-friendly slug format,
    allowing only lowercase letters, digits, and hyphens.

    Example:
        >>> to_slug("Héllø_Wörld!")
        'hello-world'
    """
    # Normalize Unicode to remove accents and convert to ASCII
    string = (
        unicodedata.normalize("NFKD", string).encode("ascii", "ignore").decode("ascii")
    )

    # Replace all non-alphanumeric characters with spaces (only keep a-z and 0-9)
    string = re.sub(r"[^a-z0-9\s]", " ", string.lower().strip())

    # Replace any sequence of spaces with a single hyphen
    return re.sub(r"\s+", "-", string)
write_str_to_file(file_path, text, overwrite=False)

Writes text to a file with file locking.

Parameters:

Name Type Description Default
file_path PathLike

The path to the file to write.

required
text str

The text to write to the file.

required
overwrite bool

Whether to overwrite the file if it exists.

False

Raises:

Type Description
FileExistsError

If the file exists and overwrite is False.

OSError

If there's an issue with file locking or writing.

Source code in src/tnh_scholar/utils/file_utils.py
def write_str_to_file(file_path: PathLike, text: str, overwrite: bool = False):
    """Writes text to a file with file locking.

    Args:
        file_path: The path to the file to write.
        text: The text to write to the file.
        overwrite: Whether to overwrite the file if it exists.

    Raises:
        FileExistsError: If the file exists and overwrite is False.
        OSError: If there's an issue with file locking or writing.
    """
    file_path = Path(file_path)

    if file_path.exists() and not overwrite:
        raise FileExistsError(f"File already exists: {file_path}")

    try:
        with file_path.open("w", encoding="utf-8") as f:
            fcntl.flock(f, fcntl.LOCK_EX)
            f.write(text)
            fcntl.flock(f, fcntl.LOCK_UN)  # Release lock
    except OSError as e:
        raise OSError(f"Error writing to or locking file {file_path}: {e}") from e

json_utils

format_json(file)

Formats a JSON file with line breaks and indentation for readability.

Parameters:

Name Type Description Default
file Path

Path to the JSON file to be formatted.

required
Example

format_json(Path("data.json"))

Source code in src/tnh_scholar/utils/json_utils.py
def format_json(file: Path) -> None:
    """
    Formats a JSON file with line breaks and indentation for readability.

    Args:
        file (Path): Path to the JSON file to be formatted.

    Example:
        format_json(Path("data.json"))
    """
    with file.open("r", encoding="utf-8") as f:
        data = json.load(f)

    with file.open("w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
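A one-shot demonstration of the effect: a compact file is rewritten in place with 4-space indentation. The logic is copied from the listing above:

```python
import json
import tempfile
from pathlib import Path

# format_json, as listed above: read, then rewrite with indentation.
def format_json(file: Path) -> None:
    data = json.loads(file.read_text(encoding="utf-8"))
    file.write_text(json.dumps(data, indent=4, ensure_ascii=False), encoding="utf-8")

path = Path(tempfile.mkdtemp()) / "data.json"
path.write_text('{"a":1,"b":[2,3]}', encoding="utf-8")
format_json(path)
formatted = path.read_text(encoding="utf-8")
```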
load_json_into_model(file, model)

Loads a JSON file and validates it against a Pydantic model.

Parameters:

Name Type Description Default
file Path

Path to the JSON file.

required
model type[BaseModel]

The Pydantic model to validate against.

required

Returns:

Name Type Description
BaseModel BaseModel

An instance of the validated Pydantic model.

Raises:

Type Description
ValueError

If the file content is invalid JSON or does not match the model.

Example:

    class ExampleModel(BaseModel):
        name: str
        age: int
        city: str

if __name__ == "__main__":
    json_file = Path("example.json")
    try:
        data = load_json_into_model(json_file, ExampleModel)
        print(data)
    except ValueError as e:
        print(e)
Source code in src/tnh_scholar/utils/json_utils.py
def load_json_into_model(file: Path, model: type[BaseModel]) -> BaseModel:
    """
    Loads a JSON file and validates it against a Pydantic model.

    Args:
        file (Path): Path to the JSON file.
        model (type[BaseModel]): The Pydantic model to validate against.

    Returns:
        BaseModel: An instance of the validated Pydantic model.

    Raises:
        ValueError: If the file content is invalid JSON or does not match the model.
    Example:
        class ExampleModel(BaseModel):
            name: str
            age: int
            city: str

        if __name__ == "__main__":
            json_file = Path("example.json")
            try:
                data = load_json_into_model(json_file, ExampleModel)
                print(data)
            except ValueError as e:
                print(e)
    """
    try:
        with file.open("r", encoding="utf-8") as f:
            data = json.load(f)
        return model(**data)
    except (json.JSONDecodeError, ValidationError) as e:
        raise ValueError(f"Error loading or validating JSON file '{file}': {e}") from e
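The same load-validate-wrap pattern can be sketched without Pydantic, which may not be installed everywhere. Below, a stdlib `dataclass` stands in for the model, and the hypothetical helper `load_json_into_dataclass` mirrors the error-wrapping shape of `load_json_into_model`:

```python
import json
import tempfile
from dataclasses import dataclass, fields
from pathlib import Path

@dataclass
class Example:
    name: str
    age: int

def load_json_into_dataclass(file: Path, cls):
    """Stdlib stand-in for load_json_into_model: load, validate, wrap errors."""
    try:
        with file.open("r", encoding="utf-8") as f:
            data = json.load(f)
        allowed = {f.name for f in fields(cls)}
        if set(data) != allowed:
            raise TypeError(f"keys {set(data)} do not match fields {allowed}")
        return cls(**data)
    except (json.JSONDecodeError, TypeError) as e:
        raise ValueError(f"Error loading or validating JSON file '{file}': {e}") from e

with tempfile.TemporaryDirectory() as tmp:
    good = Path(tmp) / "good.json"
    good.write_text('{"name": "John", "age": 30}', encoding="utf-8")
    loaded = load_json_into_dataclass(good, Example)

    bad = Path(tmp) / "bad.json"
    bad.write_text("not json", encoding="utf-8")
    try:
        load_json_into_dataclass(bad, Example)
        failed = False
    except ValueError:
        failed = True
```

The key design point, shared with the real utility, is that both parse errors and validation errors surface as a single `ValueError` chained to the original exception.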
load_jsonl_to_dict(file_path)

Load a JSONL file into a list of dictionaries.

Parameters:

    file_path (Path): Path to the JSONL file. [required]

Returns:

    List[Dict]: A list of dictionaries, each representing a line in the JSONL file.

Example:

    >>> from pathlib import Path
    >>> file_path = Path("data.jsonl")
    >>> data = load_jsonl_to_dict(file_path)
    >>> print(data)
    [{'key1': 'value1'}, {'key2': 'value2'}]

Source code in src/tnh_scholar/utils/json_utils.py
def load_jsonl_to_dict(file_path: Path) -> List[Dict]:
    """
    Load a JSONL file into a list of dictionaries.

    Args:
        file_path (Path): Path to the JSONL file.

    Returns:
        List[Dict]: A list of dictionaries, each representing a line in the JSONL file.

    Example:
        >>> from pathlib import Path
        >>> file_path = Path("data.jsonl")
        >>> data = load_jsonl_to_dict(file_path)
        >>> print(data)
        [{'key1': 'value1'}, {'key2': 'value2'}]
    """
    with file_path.open("r", encoding="utf-8") as file:
        return [json.loads(line.strip()) for line in file if line.strip()]
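A self-contained round-trip sketch of the JSONL convention this utility assumes, one JSON object per line, with blank lines skipped (the loader body is restated here for runnability):

```python
import json
import tempfile
from pathlib import Path

def load_jsonl_to_dict(file_path: Path):
    """Restated body of the utility above: one JSON object per non-blank line."""
    with file_path.open("r", encoding="utf-8") as file:
        return [json.loads(line.strip()) for line in file if line.strip()]

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "data.jsonl"
    records = [{"key1": "value1"}, {"key2": "value2"}]
    with path.open("w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec) + "\n")
        f.write("\n")  # trailing blank line is skipped by the loader
    loaded = load_jsonl_to_dict(path)
```

Because each line is parsed independently, JSONL files can be appended to and streamed without re-reading earlier records.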
save_model_to_json(file, model, indent=4, ensure_ascii=False)

Saves a Pydantic model to a JSON file, formatted with indentation for readability.

Parameters:

    file (Path): Path to the JSON file where the model will be saved. [required]
    model (BaseModel): The Pydantic model instance to save. [required]
    indent (int): Number of spaces for JSON indentation. Defaults to 4.
    ensure_ascii (bool): Whether to escape non-ASCII characters. Defaults to False.

Raises:

    ValueError: If the model cannot be serialized to JSON.
    IOError: If there is an issue writing to the file.

Example:

    class ExampleModel(BaseModel):
        name: str
        age: int

    if __name__ == "__main__":
        model_instance = ExampleModel(name="John", age=30)
        json_file = Path("example.json")
        try:
            save_model_to_json(json_file, model_instance)
            print(f"Model saved to {json_file}")
        except (ValueError, IOError) as e:
            print(e)

Source code in src/tnh_scholar/utils/json_utils.py
def save_model_to_json(
    file: Path, model: BaseModel, indent: int = 4, ensure_ascii: bool = False
) -> None:
    """
    Saves a Pydantic model to a JSON file, formatted with indentation for readability.

    Args:
        file (Path): Path to the JSON file where the model will be saved.
        model (BaseModel): The Pydantic model instance to save.
        indent (int): Number of spaces for JSON indentation. Defaults to 4.
        ensure_ascii (bool): Whether to escape non-ASCII characters. Defaults to False.

    Raises:
        ValueError: If the model cannot be serialized to JSON.
        IOError: If there is an issue writing to the file.

    Example:
        class ExampleModel(BaseModel):
            name: str
            age: int

        if __name__ == "__main__":
            model_instance = ExampleModel(name="John", age=30)
            json_file = Path("example.json")
            try:
                save_model_to_json(json_file, model_instance)
                print(f"Model saved to {json_file}")
            except (ValueError, IOError) as e:
                print(e)
    """
    try:
        # Serialize model to JSON string
        model_dict = model.model_dump()
    except TypeError as e:
        raise ValueError(f"Error serializing model to JSON: {e}") from e

    # Write the JSON string to the file
    write_data_to_json_file(file, model_dict, indent=indent, ensure_ascii=ensure_ascii)
write_data_to_json_file(file, data, indent=4, ensure_ascii=False)

Writes a dictionary or list as a JSON string to a file, ensuring the parent directory exists, and supports formatting with indentation and ASCII control.

Parameters:

    file (Path): Path to the JSON file where the data will be written. [required]
    data (Union[dict, list]): The data to write to the file. Typically a dict or list. [required]
    indent (int): Number of spaces for JSON indentation. Defaults to 4.
    ensure_ascii (bool): Whether to escape non-ASCII characters. Defaults to False.

Raises:

    ValueError: If the data cannot be serialized to JSON.
    IOError: If there is an issue writing to the file.

Example:

    >>> from pathlib import Path
    >>> data = {"key": "value"}
    >>> write_data_to_json_file(Path("output.json"), data, indent=2, ensure_ascii=True)

Source code in src/tnh_scholar/utils/json_utils.py
def write_data_to_json_file(
    file: Path, data: Union[dict, list], indent: int = 4, ensure_ascii: bool = False
) -> None:
    """
    Writes a dictionary or list as a JSON string to a file, 
    ensuring the parent directory exists,
    and supports formatting with indentation and ASCII control.

    Args:
        file (Path): Path to the JSON file where the data will be written.
        data (Union[dict, list]): The data to write to the file. Typically a dict or list.
        indent (int): Number of spaces for JSON indentation. Defaults to 4.
        ensure_ascii (bool): Whether to escape non-ASCII characters. Defaults to False.

    Raises:
        ValueError: If the data cannot be serialized to JSON.
        IOError: If there is an issue writing to the file.

    Example:
        >>> from pathlib import Path
        >>> data = {"key": "value"}
        >>> write_data_to_json_file(Path("output.json"), data, indent=2, ensure_ascii=True)
    """
    try:
        # Convert the data to a formatted JSON string
        json_str = json.dumps(data, indent=indent, ensure_ascii=ensure_ascii)
    except TypeError as e:
        raise ValueError(f"Error serializing data to JSON: {e}") from e

    try:
        # Ensure the parent directory exists
        file.parent.mkdir(parents=True, exist_ok=True)

        # Write the JSON string to the file
        with file.open("w", encoding="utf-8") as f:
            f.write(json_str)
    except IOError as e:
        raise IOError(f"Error writing JSON string to file '{file}': {e}") from e
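A runnable sketch of the two behaviors worth noting in this utility: missing parent directories are created automatically, and serialization failures are surfaced as `ValueError` before anything touches the filesystem (the function body is restated here for runnability):

```python
import json
import tempfile
from pathlib import Path
from typing import Union

def write_data_to_json_file(
    file: Path, data: Union[dict, list], indent: int = 4, ensure_ascii: bool = False
) -> None:
    """Restated body of the utility above: serialize first, then mkdir and write."""
    try:
        json_str = json.dumps(data, indent=indent, ensure_ascii=ensure_ascii)
    except TypeError as e:
        raise ValueError(f"Error serializing data to JSON: {e}") from e
    file.parent.mkdir(parents=True, exist_ok=True)
    with file.open("w", encoding="utf-8") as f:
        f.write(json_str)

with tempfile.TemporaryDirectory() as tmp:
    # The nested directories do not exist yet; they are created on demand.
    target = Path(tmp) / "nested" / "dir" / "out.json"
    write_data_to_json_file(target, {"title": "Thiền"})
    written = target.read_text(encoding="utf-8")

    try:
        write_data_to_json_file(target, {"bad": {1, 2}})  # sets are not JSON-serializable
        raised = False
    except ValueError:
        raised = True
```

Serializing before writing means a bad payload never truncates an existing file.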

lang

logger = get_child_logger(__name__) module-attribute
get_language_code_from_text(text)

Detect the language of the provided text using langdetect.

Parameters:

    text (str): Text to analyze. [required]

Returns:

    str: ISO 639-1 code for the detected language ('un' if detection fails).

Raises:

    ValueError: If text is empty or invalid

Source code in src/tnh_scholar/utils/lang.py
def get_language_code_from_text(text: str) -> str:
    """
    Detect the language of the provided text using langdetect.

    Args:
        text: Text to analyze.

    Returns:
        str: ISO 639-1 code for the detected language ('un' if detection fails).

    Raises:
        ValueError: If text is empty or invalid
    """

    if not text or text.isspace():
        raise ValueError("Input text cannot be empty")

    sample = _get_sample_text(text)

    try:
        return detect(sample)
    except LangDetectException:
        logger.warning("Language could not be detected in get_language_code_from_text().")
        return "un"
get_language_from_code(code)
Source code in src/tnh_scholar/utils/lang.py
def get_language_from_code(code: str):
    if language := pycountry.languages.get(alpha_2=code):
        return language.name
    logger.warning(f"No language name found for code: {code}")
    return "Unknown"
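The walrus-and-fallback shape of `get_language_from_code` can be sketched without `pycountry` (which may not be installed); here a small hand-rolled ISO 639-1 mapping stands in for the library lookup:

```python
# Hand-rolled stand-in for pycountry: a tiny ISO 639-1 code-to-name mapping.
LANGUAGE_NAMES = {"en": "English", "vi": "Vietnamese", "fr": "French"}

def get_language_from_code(code: str) -> str:
    """Same walrus-and-fallback shape as the pycountry-backed version above."""
    if language := LANGUAGE_NAMES.get(code):
        return language
    # The real utility logs a warning here before falling back.
    return "Unknown"
```

The walrus operator lets the lookup and the truthiness check share one expression, with the unknown-code case falling through to a sentinel value rather than raising.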
get_language_name_from_text(text)
Source code in src/tnh_scholar/utils/lang.py
36
37
def get_language_name_from_text(text: str) -> str:
    return get_language_from_code(get_language_code_from_text(text))

progress_utils

BAR_FORMAT = '{desc}: {percentage:3.0f}%|{bar}| Total: {total_fmt} sec. [elapsed: {elapsed}]' module-attribute
ExpectedTimeTQDM

A context manager for a time-based tqdm progress bar with optional delay.

  • 'expected_time': number of seconds we anticipate the task might take.
  • 'display_interval': how often (seconds) to refresh the bar.
  • 'desc': a short description for the bar.
  • 'delay_start': how many seconds to wait (sleep) before we even create/start the bar.

If the task finishes before 'delay_start' has elapsed, the bar may never appear.

Source code in src/tnh_scholar/utils/progress_utils.py
class ExpectedTimeTQDM:
    """
    A context manager for a time-based tqdm progress bar with optional delay.

    - 'expected_time': number of seconds we anticipate the task might take.
    - 'display_interval': how often (seconds) to refresh the bar.
    - 'desc': a short description for the bar.
    - 'delay_start': how many seconds to wait (sleep) before we even create/start the bar.

    If the task finishes before 'delay_start' has elapsed, the bar may never appear.
    """

    def __init__(
        self,
        expected_time: float,
        display_interval: float = 0.5,
        desc: str = "Time-based Progress",
        delay_start: float = 1.0,
    ) -> None:
        self.expected_time = round(expected_time)  # use nearest second.
        self.display_interval = display_interval
        self.desc = desc
        self.delay_start = delay_start

        self._stop_event = threading.Event()
        self._pbar = None  # We won't create the bar until after 'delay_start'
        self._start_time = None

    def __enter__(self):
        # Record the start time for reference
        self._start_time = time.time()

        # Spawn the background thread; it will handle waiting and then creating/updating the bar
        self._thread = threading.Thread(target=self._update_bar, daemon=True)
        self._thread.start()

        return self

    def _update_bar(self):
        # 1) Delay so warnings/logs can appear before the bar
        if self.delay_start > 0:
            time.sleep(self.delay_start)

        # 2) Create the tqdm bar (only now does it appear)
        self._pbar = tqdm(
            total=self.expected_time, desc=self.desc, unit="sec", bar_format=BAR_FORMAT
        )

        # 3) Update until told to stop
        while not self._stop_event.is_set():
            elapsed = time.time() - self._start_time
            current_value = min(elapsed, self.expected_time)
            if self._pbar:
                self._pbar.n = round(current_value)
                self._pbar.refresh()
            time.sleep(self.display_interval)

    def __exit__(self, exc_type, exc_value, traceback):
        # Signal the thread to stop
        self._stop_event.set()
        self._thread.join()

        # If the bar was actually created (i.e., we didn't finish too quickly),
        # do a final update and close
        if self._pbar:
            elapsed = time.time() - self._start_time
            self._pbar.n = round(min(elapsed, self.expected_time))
            self._pbar.refresh()
            self._pbar.close()

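The background-thread lifecycle that `ExpectedTimeTQDM` builds on can be shown without tqdm (which may not be installed). This hypothetical `TickingProgress` class keeps only the skeleton: a daemon thread started in `__enter__`, a `threading.Event` to signal shutdown, and a `join` in `__exit__`:

```python
import threading
import time

class TickingProgress:
    """Minimal sketch of the ExpectedTimeTQDM thread pattern, without tqdm."""

    def __init__(self, display_interval: float = 0.01):
        self.display_interval = display_interval
        self.ticks = 0
        self._stop_event = threading.Event()

    def __enter__(self):
        # Daemon thread: it never blocks interpreter shutdown.
        self._thread = threading.Thread(target=self._update, daemon=True)
        self._thread.start()
        return self

    def _update(self):
        while not self._stop_event.is_set():
            self.ticks += 1  # a real bar would refresh tqdm here
            time.sleep(self.display_interval)

    def __exit__(self, exc_type, exc_value, traceback):
        self._stop_event.set()
        self._thread.join()  # wait for the final tick to finish

with TickingProgress() as progress:
    time.sleep(0.1)  # simulate work
```

The `Event`-then-`join` shutdown order matters: setting the event first guarantees the worker loop observes the stop request, and joining afterwards ensures no update races the final display.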
delay_start = delay_start instance-attribute
desc = desc instance-attribute
display_interval = display_interval instance-attribute
expected_time = round(expected_time) instance-attribute
__enter__()
Source code in src/tnh_scholar/utils/progress_utils.py
def __enter__(self):
    # Record the start time for reference
    self._start_time = time.time()

    # Spawn the background thread; it will handle waiting and then creating/updating the bar
    self._thread = threading.Thread(target=self._update_bar, daemon=True)
    self._thread.start()

    return self
__exit__(exc_type, exc_value, traceback)
Source code in src/tnh_scholar/utils/progress_utils.py
def __exit__(self, exc_type, exc_value, traceback):
    # Signal the thread to stop
    self._stop_event.set()
    self._thread.join()

    # If the bar was actually created (i.e., we didn't finish too quickly),
    # do a final update and close
    if self._pbar:
        elapsed = time.time() - self._start_time
        self._pbar.n = round(min(elapsed, self.expected_time))
        self._pbar.refresh()
        self._pbar.close()
__init__(expected_time, display_interval=0.5, desc='Time-based Progress', delay_start=1.0)
Source code in src/tnh_scholar/utils/progress_utils.py
def __init__(
    self,
    expected_time: float,
    display_interval: float = 0.5,
    desc: str = "Time-based Progress",
    delay_start: float = 1.0,
) -> None:
    self.expected_time = round(expected_time)  # use nearest second.
    self.display_interval = display_interval
    self.desc = desc
    self.delay_start = delay_start

    self._stop_event = threading.Event()
    self._pbar = None  # We won't create the bar until after 'delay_start'
    self._start_time = None
TimeProgress

A context manager for a time-based progress display using dots.

The display updates once per second, printing a dot and showing:

  • Expected time (if provided)
  • Elapsed time (always displayed)

Example:

    >>> import time
    >>> with TimeProgress(expected_time=60, desc="Transcribing..."):
    ...     time.sleep(5)  # Simulate work
    [Expected Time: 1:00, Elapsed Time: 0:05] .....

Parameters:

    expected_time (Optional[float]): Expected time in seconds. Defaults to None.
    display_interval (float): How often to print a dot (seconds). Defaults to 1.0.
    desc (str): Description to display alongside the progress. Defaults to ''.
Source code in src/tnh_scholar/utils/progress_utils.py
class TimeProgress:
    """
    A context manager for a time-based progress display using dots.

    The display updates once per second, printing a dot and showing:
    - Expected time (if provided)
    - Elapsed time (always displayed)

    Example:
    >>> import time
    >>> with TimeProgress(expected_time=60, desc="Transcribing..."):
    ...     time.sleep(5)  # Simulate work
    [Expected Time: 1:00, Elapsed Time: 0:05] .....

    Args:
        expected_time (Optional[float]): Expected time in seconds. Optional.
        display_interval (float): How often to print a dot (seconds).
        desc (str): Description to display alongside the progress.
    """

    def __init__(
        self,
        expected_time: Optional[float] = None,
        display_interval: float = 1.0,
        desc: str = "",
    ):
        self.expected_time = expected_time
        self.display_interval = display_interval
        self._stop_event = threading.Event()
        self._start_time = None
        self._thread = None
        self.desc = desc
        self._last_length = 0  # To keep track of the last printed line length

    def __enter__(self):
        # Record the start time
        self._start_time = time.time()

        # Spawn the background thread
        self._thread = threading.Thread(target=self._print_progress, daemon=True)
        self._thread.start()

        return self

    def _print_progress(self):
        """
        Continuously prints progress alternating between | and — along with elapsed/expected time.
        """
        symbols = ["|", "/", "—", "\\"]  # Symbols to alternate between
        symbol_index = 0  # Keep track of the current symbol

        while not self._stop_event.is_set():
            elapsed = time.time() - self._start_time

            # Format elapsed time as mm:ss
            elapsed_str = self._format_time(elapsed)

            # Format expected time if provided
            if self.expected_time is not None:
                expected_str = self._format_time(self.expected_time)
                header = f"{self.desc} [Expected Time: {expected_str}, Elapsed Time: {elapsed_str}]"
            else:
                header = f"{self.desc} [Elapsed Time: {elapsed_str}]"

            # Get the current symbol for the spinner
            spinner = symbols[symbol_index]

            # Construct the line with the spinner
            line = f"\r{header} {spinner}"

            # Write to stdout
            sys.stdout.write(line)
            sys.stdout.flush()

            # Update the symbol index to alternate
            symbol_index = (symbol_index + 1) % len(symbols)

            # Sleep before next update
            time.sleep(self.display_interval)

        # Clear the spinner after finishing
        sys.stdout.write("\r" + " " * len(line) + "\r")
        sys.stdout.flush()

    def __exit__(self, exc_type, exc_value, traceback):
        # Signal the thread to stop
        self._stop_event.set()
        self._thread.join()

        # Final elapsed time
        elapsed = time.time() - self._start_time
        elapsed_str = self._format_time(elapsed)

        # Construct the final line
        if self.expected_time is not None:
            expected_str = self._format_time(self.expected_time)
            final_header = f"{self.desc} [Expected Time: {expected_str}, Elapsed Time: {elapsed_str}]"
        else:
            final_header = f"{self.desc} [Elapsed Time: {elapsed_str}]"

        # Final dots
        final_line = f"\r{final_header}"

        # Clear the line and move to the next line
        padding = " " * max(self._last_length - len(final_line), 0)
        sys.stdout.write(final_line + padding + "\n")
        sys.stdout.flush()

    @staticmethod
    def _format_time(seconds: float) -> str:
        """
        Converts seconds to a formatted string (mm:ss).
        """
        minutes = int(seconds // 60)
        seconds = int(seconds % 60)
        return f"{minutes}:{seconds:02}"
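The `_format_time` helper above is the one piece of `TimeProgress` that is pure and easy to exercise in isolation; restated standalone:

```python
def format_time(seconds: float) -> str:
    """Standalone copy of TimeProgress._format_time: seconds -> m:ss."""
    minutes = int(seconds // 60)
    secs = int(seconds % 60)
    return f"{minutes}:{secs:02}"
```

Note that minutes are not capped at 59, so durations over an hour render as e.g. `60:00` rather than rolling over to hours.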
desc = desc instance-attribute
display_interval = display_interval instance-attribute
expected_time = expected_time instance-attribute
__enter__()
Source code in src/tnh_scholar/utils/progress_utils.py
def __enter__(self):
    # Record the start time
    self._start_time = time.time()

    # Spawn the background thread
    self._thread = threading.Thread(target=self._print_progress, daemon=True)
    self._thread.start()

    return self
__exit__(exc_type, exc_value, traceback)
Source code in src/tnh_scholar/utils/progress_utils.py
def __exit__(self, exc_type, exc_value, traceback):
    # Signal the thread to stop
    self._stop_event.set()
    self._thread.join()

    # Final elapsed time
    elapsed = time.time() - self._start_time
    elapsed_str = self._format_time(elapsed)

    # Construct the final line
    if self.expected_time is not None:
        expected_str = self._format_time(self.expected_time)
        final_header = f"{self.desc} [Expected Time: {expected_str}, Elapsed Time: {elapsed_str}]"
    else:
        final_header = f"{self.desc} [Elapsed Time: {elapsed_str}]"

    # Final dots
    final_line = f"\r{final_header}"

    # Clear the line and move to the next line
    padding = " " * max(self._last_length - len(final_line), 0)
    sys.stdout.write(final_line + padding + "\n")
    sys.stdout.flush()
__init__(expected_time=None, display_interval=1.0, desc='')
Source code in src/tnh_scholar/utils/progress_utils.py
def __init__(
    self,
    expected_time: Optional[float] = None,
    display_interval: float = 1.0,
    desc: str = "",
):
    self.expected_time = expected_time
    self.display_interval = display_interval
    self._stop_event = threading.Event()
    self._start_time = None
    self._thread = None
    self.desc = desc
    self._last_length = 0  # To keep track of the last printed line length

timing_utils

TimeMs

Bases: int

Lightweight representation of a time interval or timestamp in milliseconds. Allows negative values.

Source code in src/tnh_scholar/utils/timing_utils.py
class TimeMs(int):
    """
    Lightweight representation of a time interval or timestamp in milliseconds.
    Allows negative values.
    """

    def __new__(cls, ms: Union[int, float, "TimeMs"]):
        if isinstance(ms, TimeMs):
            value = int(ms)
        elif isinstance(ms, (int, float)):
            if not math.isfinite(ms):
                raise ValueError("ms must be a finite number")
            value = round(ms)
        else:
            raise TypeError(f"ms must be a number or TimeMs, got {type(ms).__name__}")
        return int.__new__(cls, value)

    @classmethod
    def from_seconds(cls, seconds: int | float) -> "TimeMs":
        return cls(round(seconds * 1000))

    def to_ms(self) -> int:
        return int(self)

    def to_seconds(self) -> float:
        return float(self) / 1000

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type, handler: GetCoreSchemaHandler):
        return core_schema.with_info_plain_validator_function(
            cls._validate,
            serialization=core_schema.plain_serializer_function_ser_schema(lambda v: int(v)),
        )

    @classmethod
    def _validate(cls, value, info):
        """
        Pydantic core validator for TimeMs.

        Args:
            value: The value to validate.
            info: Pydantic core schema info (unused).

        Returns:
            TimeMs: Validated TimeMs instance.
        """
        return cls(value)

    def __add__(self, other):
        return TimeMs(int(self) + int(other))

    def __radd__(self, other):
        return TimeMs(int(other) + int(self))

    def __sub__(self, other):
        return TimeMs(int(self) - int(other))

    def __rsub__(self, other):
        return TimeMs(int(other) - int(self))

    def __repr__(self) -> str:
        return f"TimeMs({self.to_seconds():.3f}s)"
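The arithmetic behavior of `TimeMs` can be exercised with a pared-down copy that omits the Pydantic integration (assumed unavailable here); note that `__rsub__` must compute `other - self`, not `self - other`:

```python
import math
from typing import Union

class TimeMs(int):
    """Pared-down copy of TimeMs above (Pydantic hooks omitted)."""

    def __new__(cls, ms: Union[int, float, "TimeMs"]):
        if isinstance(ms, (int, float)):  # TimeMs is itself an int subclass
            if not math.isfinite(ms):
                raise ValueError("ms must be a finite number")
            return int.__new__(cls, round(ms))
        raise TypeError(f"ms must be a number, got {type(ms).__name__}")

    @classmethod
    def from_seconds(cls, seconds: Union[int, float]) -> "TimeMs":
        return cls(round(seconds * 1000))

    def to_seconds(self) -> float:
        return float(self) / 1000

    def __add__(self, other):
        return TimeMs(int(self) + int(other))

    def __sub__(self, other):
        return TimeMs(int(self) - int(other))

    def __rsub__(self, other):
        # reflected subtraction: other - self, so operand order is reversed
        return TimeMs(int(other) - int(self))

start = TimeMs.from_seconds(1.5)   # 1500 ms
end = TimeMs.from_seconds(2.25)    # 2250 ms
delta = end - start
```

Because `TimeMs` subclasses `int` and overrides the reflected methods, expressions such as `3000 - start` also return a `TimeMs` instead of a plain `int`.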
__add__(other)
Source code in src/tnh_scholar/utils/timing_utils.py
def __add__(self, other):
    return TimeMs(int(self) + int(other))
__get_pydantic_core_schema__(source_type, handler) classmethod
Source code in src/tnh_scholar/utils/timing_utils.py
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler: GetCoreSchemaHandler):
    return core_schema.with_info_plain_validator_function(
        cls._validate,
        serialization=core_schema.plain_serializer_function_ser_schema(lambda v: int(v)),
    )
__new__(ms)
Source code in src/tnh_scholar/utils/timing_utils.py
def __new__(cls, ms: Union[int, float, "TimeMs"]):
    if isinstance(ms, TimeMs):
        value = int(ms)
    elif isinstance(ms, (int, float)):
        if not math.isfinite(ms):
            raise ValueError("ms must be a finite number")
        value = round(ms)
    else:
        raise TypeError(f"ms must be a number or TimeMs, got {type(ms).__name__}")
    return int.__new__(cls, value)
__radd__(other)
Source code in src/tnh_scholar/utils/timing_utils.py
def __radd__(self, other):
    return TimeMs(int(other) + int(self))
__repr__()
Source code in src/tnh_scholar/utils/timing_utils.py
def __repr__(self) -> str:
    return f"TimeMs({self.to_seconds():.3f}s)"
__rsub__(other)
Source code in src/tnh_scholar/utils/timing_utils.py
def __rsub__(self, other):
    return TimeMs(int(other) - int(self))
__sub__(other)
Source code in src/tnh_scholar/utils/timing_utils.py
def __sub__(self, other):
    return TimeMs(int(self) - int(other))
from_seconds(seconds) classmethod
Source code in src/tnh_scholar/utils/timing_utils.py
@classmethod
def from_seconds(cls, seconds: int | float) -> "TimeMs":
    return cls(round(seconds * 1000))
to_ms()
Source code in src/tnh_scholar/utils/timing_utils.py
def to_ms(self) -> int:
    return int(self)
to_seconds()
Source code in src/tnh_scholar/utils/timing_utils.py
def to_seconds(self) -> float:
    return float(self) / 1000
convert_ms_to_sec(ms)

Convert time from milliseconds (int) to seconds (float).

Source code in src/tnh_scholar/utils/timing_utils.py
def convert_ms_to_sec(ms: int) -> float:
    """Convert time from milliseconds (int) to seconds (float)."""
    return float(ms / 1000)
convert_sec_to_ms(val)

Convert seconds to milliseconds, rounding to the nearest integer.

Source code in src/tnh_scholar/utils/timing_utils.py
def convert_sec_to_ms(val: float) -> int:
    """ 
    Convert seconds to milliseconds, rounding to the nearest integer.
    """
    return round(val * 1000)

tnh_audio_segment

TNHAudioSegment: A typed, minimal wrapper for pydub.AudioSegment.

This class provides a type-safe interface for working with audio segments using pydub, enabling easier composition, slicing, and manipulation of audio data. It exposes common operations such as concatenation, slicing, and length retrieval, while hiding the underlying pydub implementation.

Key features
  • Type-annotated methods for static analysis and IDE support
  • Static constructors for silent and empty segments
  • Operator overloads for concatenation and slicing
  • Access to the underlying pydub.AudioSegment via the raw property

Extend this class with additional methods as needed for your audio processing workflows.

TNHAudioSegment
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
class TNHAudioSegment:
    def __init__(self, segment: _AudioSegment):
        self._segment = segment

    @staticmethod
    def from_file(file: str | Path | BytesIO, format: str | None = None, **kwargs) -> "TNHAudioSegment":
        """
        Wrapper: Load an audio file into a TNHAudioSegment.

        Args:
            file: Path to the audio file.
            format: Optional audio format (e.g., 'mp3', 'wav'). If None, pydub will attempt to infer it.
            **kwargs: Additional keyword arguments passed to pydub.AudioSegment.from_file.

        Returns:
            TNHAudioSegment instance containing the loaded audio.
        """
        return TNHAudioSegment(_AudioSegment.from_file(file, format=format, **kwargs))

    def export(self, out_f: str | BinaryIO, format: str, **kwargs) -> None:
        """
        Wrapper: Export the audio segment to a file-like object or file path.

        Args:
            out_f: File path or file-like object to write the audio data to.
            format: Audio format (e.g., 'mp3', 'wav').
            **kwargs: Additional keyword arguments passed to pydub.AudioSegment.export.
        """
        self._segment.export(out_f, format=format, **kwargs)

    @staticmethod
    def silent(duration: int) -> "TNHAudioSegment":
        return TNHAudioSegment(_AudioSegment.silent(duration=duration))

    @staticmethod
    def empty() -> "TNHAudioSegment":
        return TNHAudioSegment(_AudioSegment.empty())

    def __getitem__(self, key: int | slice) -> "TNHAudioSegment":
        return TNHAudioSegment(self._segment[key]) # type: ignore

    def __add__(self, other: "TNHAudioSegment") -> "TNHAudioSegment":
        return TNHAudioSegment(self._segment + other._segment)

    def __iadd__(self, other: "TNHAudioSegment") -> "TNHAudioSegment":
        self._segment = self._segment + other._segment
        return self

    def __len__(self) -> int:
        return len(self._segment)

    # Add more methods as needed, e.g., export, from_file, etc.

    @property
    def raw(self) -> _AudioSegment:
        """Access the underlying pydub.AudioSegment if needed."""
        return self._segment
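The delegation pattern used by `TNHAudioSegment` can be sketched with `bytes` standing in for `pydub.AudioSegment` (which may not be installed); this hypothetical `Wrapper` keeps the same surface: concatenation, slicing, length, and a `raw` escape hatch:

```python
class Wrapper:
    """Illustration of the TNHAudioSegment delegation pattern over bytes."""

    def __init__(self, payload: bytes):
        self._payload = payload

    @staticmethod
    def empty() -> "Wrapper":
        return Wrapper(b"")

    def __getitem__(self, key) -> "Wrapper":
        result = self._payload[key]
        if isinstance(result, int):  # single-index access on bytes yields an int
            result = bytes([result])
        return Wrapper(result)

    def __add__(self, other: "Wrapper") -> "Wrapper":
        # Delegate to the wrapped object, then re-wrap the result.
        return Wrapper(self._payload + other._payload)

    def __len__(self) -> int:
        return len(self._payload)

    @property
    def raw(self) -> bytes:
        """Access the underlying payload if needed."""
        return self._payload

clip = Wrapper(b"abcdef")
combined = Wrapper.empty() + clip + clip[:3]
```

Re-wrapping every result is what keeps the typed interface closed under its own operations, so chained expressions never leak the underlying pydub type.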
raw property

Access the underlying pydub.AudioSegment if needed.

__add__(other)
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
def __add__(self, other: "TNHAudioSegment") -> "TNHAudioSegment":
    return TNHAudioSegment(self._segment + other._segment)
__getitem__(key)
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
def __getitem__(self, key: int | slice) -> "TNHAudioSegment":
    return TNHAudioSegment(self._segment[key]) # type: ignore
__iadd__(other)
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
def __iadd__(self, other: "TNHAudioSegment") -> "TNHAudioSegment":
    self._segment = self._segment + other._segment
    return self
__init__(segment)
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
def __init__(self, segment: _AudioSegment):
    self._segment = segment
__len__()
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
def __len__(self) -> int:
    return len(self._segment)
empty() staticmethod
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
@staticmethod
def empty() -> "TNHAudioSegment":
    return TNHAudioSegment(_AudioSegment.empty())
export(out_f, format, **kwargs)

Wrapper: Export the audio segment to a file-like object or file path.

Parameters:

    out_f (str | BinaryIO): File path or file-like object to write the audio data to. [required]
    format (str): Audio format (e.g., 'mp3', 'wav'). [required]
    **kwargs: Additional keyword arguments passed to pydub.AudioSegment.export.
Source code in src/tnh_scholar/utils/tnh_audio_segment.py
def export(self, out_f: str | BinaryIO, format: str, **kwargs) -> None:
    """
    Wrapper: Export the audio segment to a file-like object or file path.

    Args:
        out_f: File path or file-like object to write the audio data to.
        format: Audio format (e.g., 'mp3', 'wav').
        **kwargs: Additional keyword arguments passed to pydub.AudioSegment.export.
    """
    self._segment.export(out_f, format=format, **kwargs)
from_file(file, format=None, **kwargs) staticmethod

Wrapper: Load an audio file into a TNHAudioSegment.

Parameters:

    file (str | Path | BytesIO): Path to the audio file. [required]
    format (str | None): Optional audio format (e.g., 'mp3', 'wav'). If None, pydub will attempt to infer it.
    **kwargs: Additional keyword arguments passed to pydub.AudioSegment.from_file.

Returns:

    TNHAudioSegment: Instance containing the loaded audio.

Source code in src/tnh_scholar/utils/tnh_audio_segment.py
@staticmethod
def from_file(file: str | Path | BytesIO, format: str | None = None, **kwargs) -> "TNHAudioSegment":
    """
    Wrapper: Load an audio file into a TNHAudioSegment.

    Args:
        file: Path to the audio file.
        format: Optional audio format (e.g., 'mp3', 'wav'). If None, pydub will attempt to infer it.
        **kwargs: Additional keyword arguments passed to pydub.AudioSegment.from_file.

    Returns:
        TNHAudioSegment instance containing the loaded audio.
    """
    return TNHAudioSegment(_AudioSegment.from_file(file, format=format, **kwargs))
silent(duration) staticmethod

Wrapper: Create a silent audio segment of the given duration (in milliseconds, per pydub's convention).

Source code in src/tnh_scholar/utils/tnh_audio_segment.py
@staticmethod
def silent(duration: int) -> "TNHAudioSegment":
    return TNHAudioSegment(_AudioSegment.silent(duration=duration))

user_io_utils

get_single_char(prompt=None)

Get a single character from input, adapting to the execution environment.

Parameters:

Name Type Description Default
prompt Optional[str]

Optional prompt to display before getting input

None

Returns:

Type Description
str

A single character string from user input

Note
  • In terminal environments, uses raw input mode without requiring Enter
  • In Jupyter/IPython, falls back to regular input with message about Enter
Source code in src/tnh_scholar/utils/user_io_utils.py
def get_single_char(prompt: Optional[str] = None) -> str:
    """
    Get a single character from input, adapting to the execution environment.

    Args:
        prompt: Optional prompt to display before getting input

    Returns:
        A single character string from user input

    Note:
        - In terminal environments, uses raw input mode without requiring Enter
        - In Jupyter/IPython, falls back to regular input with message about Enter
    """
    # Check if we're in IPython/Jupyter
    is_notebook = hasattr(sys, 'ps1') or bool(sys.flags.interactive)

    if prompt:
        print(prompt, end='', flush=True)

    if is_notebook:
        # Jupyter/IPython environment - use regular input
        entry = input("Single character input required ")
        return entry[0] if entry else "\n" # Use newline if no entry

    # Terminal environment
    if os.name == "nt":  # Windows
        import msvcrt
        return msvcrt.getch().decode("utf-8")
    else:  # Unix-like
        import termios
        import tty

        try:
            fd = sys.stdin.fileno()
            old_settings = termios.tcgetattr(fd)
            try:
                tty.setraw(fd)
                char = sys.stdin.read(1)
            finally:
                termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
            return char
        except termios.error:
            # Fallback if terminal handling fails
            return input("Single character input required ")[0]
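The environment check at the top of get_single_char can be exercised on its own. In a plain (non-interactive) script run, both conditions are false:

```python
import sys

# The notebook-detection heuristic from get_single_char, shown standalone.
# hasattr(sys, 'ps1') is True only in an interactive REPL, and
# sys.flags.interactive is nonzero only when Python is launched with -i.
def running_interactively() -> bool:
    return hasattr(sys, "ps1") or bool(sys.flags.interactive)
```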
get_user_confirmation(prompt, default=True)

Prompt the user for a yes/no confirmation with single-character input. Cross-platform implementation. Returns True if 'y' is entered and False if 'n'; pressing Enter returns the default value.

Example usage:
    if get_user_confirmation("Do you want to continue"):
        print("Continuing...")
    else:
        print("Exiting...")

Source code in src/tnh_scholar/utils/user_io_utils.py
def get_user_confirmation(prompt: str, default: bool = True) -> bool:
    """
    Prompt the user for a yes/no confirmation with single-character input.
    Cross-platform implementation. Returns True if 'y' is entered and False if 'n'.
    Pressing Enter returns the default value.

    Example usage:
        if get_user_confirmation("Do you want to continue"):
            print("Continuing...")
        else:
            print("Exiting...")
    """
    print(f"{prompt} ", end="", flush=True)

    while True:
        char = get_single_char().lower()
        if char == "y":
            print(char)  # Echo the choice
            return True
        elif char == "n":
            print(char)
            return False
        elif char in ("\r", "\n"):  # Enter key (use default)
            print()  # Add a newline
            return default
        else:
            print(
                f"\nInvalid input: {char}. Please type 'y' or 'n': ", end="", flush=True
            )
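The loop's decision logic can be factored into a pure function for testing (an illustrative refactoring, not part of the package API):

```python
from typing import Optional

# Decision logic of the confirmation loop as a pure function. Returns True
# for 'y', False for 'n', the default for Enter, and None for any other
# character, signalling that the caller should re-prompt.
def interpret_choice(char: str, default: bool = True) -> Optional[bool]:
    char = char.lower()
    if char == "y":
        return True
    if char == "n":
        return False
    if char in ("\r", "\n"):  # Enter key: fall back to the default
        return default
    return None
```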

validate

OCR_ENV_VARS = {'GOOGLE_APPLICATION_CREDENTIALS'} module-attribute
OPENAI_ENV_VARS = {'OPENAI_API_KEY'} module-attribute
logger = get_child_logger(__name__) module-attribute
check_env(required_vars, feature='this feature', output=True)

Check environment variables and provide user-friendly error messages.

Parameters:

Name Type Description Default
required_vars Set[str]

Set of environment variable names to check

required
feature str

Description of feature requiring these variables

'this feature'

Returns:

Name Type Description
bool bool

True if all required variables are set

Source code in src/tnh_scholar/utils/validate.py
def check_env(required_vars: Set[str], feature: str = "this feature", output: bool = True) -> bool:
    """
    Check environment variables and provide user-friendly error messages.

    Args:
        required_vars: Set of environment variable names to check
        feature: Description of feature requiring these variables
        output: If True, log and print an error message when variables are missing

    Returns:
        bool: True if all required variables are set
    """
    if missing := [var for var in required_vars if not os.getenv(var)]:
        if output:
            message = get_env_message(missing, feature)
            logger.error(f"Missing environment variables: {', '.join(missing)}")
            print(message, file=sys.stderr)
        return False
    return True
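The core test inside check_env is a simple scan for unset variables; a minimal stand-in, runnable without the package (the DEMO_* variable names are invented for the example):

```python
import os

# Stand-in for check_env's core logic: collect every required variable
# that is unset or empty in the current environment.
def missing_env(required):
    return [var for var in required if not os.getenv(var)]

# Arrange one set and one unset variable, then scan.
os.environ["DEMO_VAR_SET"] = "1"
os.environ.pop("DEMO_VAR_UNSET", None)
missing = missing_env(["DEMO_VAR_SET", "DEMO_VAR_UNSET"])
```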
check_ocr_env(output=True)

Check OCR processing requirements.

Source code in src/tnh_scholar/utils/validate.py
def check_ocr_env(output: bool = True) -> bool:
    """Check OCR processing requirements."""
    return check_env(OCR_ENV_VARS, "OCR processing", output=output)
check_openai_env(output=True)

Check OpenAI API requirements.

Source code in src/tnh_scholar/utils/validate.py
def check_openai_env(output: bool = True) -> bool:
    """Check OpenAI API requirements."""
    return check_env(OPENAI_ENV_VARS, "OpenAI API access", output=output)
get_env_message(missing_vars, feature='this feature')

Generate user-friendly environment setup message.

Parameters:

Name Type Description Default
missing_vars List[str]

List of missing environment variable names

required
feature str

Name of feature requiring the variables

'this feature'

Returns:

Type Description
str

Formatted error message with setup instructions

Source code in src/tnh_scholar/utils/validate.py
def get_env_message(missing_vars: List[str], feature: str = "this feature") -> str:
    """Generate user-friendly environment setup message.

    Args:
        missing_vars: List of missing environment variable names
        feature: Name of feature requiring the variables

    Returns:
        Formatted error message with setup instructions
    """
    export_cmds = " ".join(f"{var}=your_{var.lower()}_here" for var in missing_vars)

    return "\n".join([
        f"\nEnvironment Error: Missing required variables for {feature}:",
        ", ".join(missing_vars),
        "\nSet variables in your shell:",
        f"export {export_cmds}",
        "\nSee documentation for details.",
        "\nFor development: Add to .env file in project root.\n"
    ])
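To see the shape of the message, the formatting can be reproduced standalone (this mirrors the source above):

```python
# Standalone copy of get_env_message's formatting, to show the message shape.
def env_message(missing_vars, feature="this feature"):
    # One "VAR=your_var_here" placeholder per missing variable.
    export_cmds = " ".join(f"{var}=your_{var.lower()}_here" for var in missing_vars)
    return "\n".join([
        f"\nEnvironment Error: Missing required variables for {feature}:",
        ", ".join(missing_vars),
        "\nSet variables in your shell:",
        f"export {export_cmds}",
        "\nSee documentation for details.",
        "\nFor development: Add to .env file in project root.\n",
    ])

msg = env_message(["OPENAI_API_KEY"], feature="OpenAI API access")
```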

version_check

Version checker package for monitoring package version compatibility.

__all__ = ['PackageVersionChecker', 'VersionCheckerConfig', 'VersionStrategy', 'Result', 'PackageInfo'] module-attribute
PackageInfo dataclass

Information about a package and its versions.

Source code in src/tnh_scholar/utils/version_check/models.py
@dataclass
class PackageInfo:
    """Information about a package and its versions."""

    name: str
    installed_version: Optional[str] = None
    latest_version: Optional[str] = None
    required_version: Optional[str] = None
installed_version = None class-attribute instance-attribute
latest_version = None class-attribute instance-attribute
name instance-attribute
required_version = None class-attribute instance-attribute
__init__(name, installed_version=None, latest_version=None, required_version=None)
PackageVersionChecker

Main class for checking package versions against requirements.

Source code in src/tnh_scholar/utils/version_check/checker.py
class PackageVersionChecker:
    """Main class for checking package versions against requirements."""

    def __init__(self, 
                 provider: Optional[VersionProvider] = None,
                 cache: Optional[VersionCache] = None):
        self.provider = provider or StandardVersionProvider()
        self.cache = cache or VersionCache()

    # TODO make this method more modular and extract out complexity  
    # also check why parse_vdiff is not being used.
    def check_version(self, 
                      package_name: str, 
                      config: Optional[VersionCheckerConfig] = None) -> Result:
        """Check if package meets version requirements based on config."""
        config = config or VersionCheckerConfig()

        try:
            # Get versions
            installed = self.provider.get_installed_version(package_name)
            latest = self.provider.get_latest_version(package_name)

            # Default values
            is_compatible = True
            needs_update = installed < latest
            warning_level = None
            diff_details = None

            # Check based on strategy
            if config.strategy == VersionStrategy.MINIMUM:
                is_compatible = check_minimum_version(installed, config.get_required_version())

            elif config.strategy == VersionStrategy.EXACT:
                is_compatible = check_exact_version(installed, config.get_required_version())

            elif config.strategy == VersionStrategy.VERSION_DIFF:
                # Check warning threshold
                if config.vdiff_warn_matrix:
                    warn_within_limits, diff_details = check_version_diff(
                        installed, latest, config.vdiff_warn_matrix)
                    if not warn_within_limits:
                        # Determine warning level based on which component exceeded threshold
                        if diff_details and "major" in diff_details and diff_details["major"] > 0:
                            warning_level = "MAJOR"
                        elif diff_details and "minor" in diff_details and diff_details["minor"] > 0:
                            warning_level = "MINOR"
                        else:
                            warning_level = "MICRO"

                # Check failure threshold
                if config.vdiff_fail_matrix:
                    fail_within_limits, diff_details = check_version_diff(
                        installed, latest, config.vdiff_fail_matrix)
                    is_compatible = fail_within_limits

            # Create package info
            package_info = PackageInfo(
                name=package_name,
                installed_version=str(installed),
                latest_version=str(latest),
                required_version=str(config.get_required_version()) if config.requirement else None
            )

            # Create and return result
            return Result(
                is_compatible=is_compatible,
                needs_update=needs_update,
                package_info=package_info,
                warning_level=warning_level,
                diff_details=diff_details
            )

        except Exception as e:
            # Handle errors based on configuration
            if config.fail_on_error:
                raise
            return Result(
                is_compatible=False,
                needs_update=False,
                package_info=PackageInfo(name=package_name),
                error=str(e)
            )
cache = cache or VersionCache() instance-attribute
provider = provider or StandardVersionProvider() instance-attribute
__init__(provider=None, cache=None)
Source code in src/tnh_scholar/utils/version_check/checker.py
def __init__(self, 
             provider: Optional[VersionProvider] = None,
             cache: Optional[VersionCache] = None):
    self.provider = provider or StandardVersionProvider()
    self.cache = cache or VersionCache()
check_version(package_name, config=None)

Check if package meets version requirements based on config.

Source code in src/tnh_scholar/utils/version_check/checker.py
def check_version(self, 
                  package_name: str, 
                  config: Optional[VersionCheckerConfig] = None) -> Result:
    """Check if package meets version requirements based on config."""
    config = config or VersionCheckerConfig()

    try:
        # Get versions
        installed = self.provider.get_installed_version(package_name)
        latest = self.provider.get_latest_version(package_name)

        # Default values
        is_compatible = True
        needs_update = installed < latest
        warning_level = None
        diff_details = None

        # Check based on strategy
        if config.strategy == VersionStrategy.MINIMUM:
            is_compatible = check_minimum_version(installed, config.get_required_version())

        elif config.strategy == VersionStrategy.EXACT:
            is_compatible = check_exact_version(installed, config.get_required_version())

        elif config.strategy == VersionStrategy.VERSION_DIFF:
            # Check warning threshold
            if config.vdiff_warn_matrix:
                warn_within_limits, diff_details = check_version_diff(
                    installed, latest, config.vdiff_warn_matrix)
                if not warn_within_limits:
                    # Determine warning level based on which component exceeded threshold
                    if diff_details and "major" in diff_details and diff_details["major"] > 0:
                        warning_level = "MAJOR"
                    elif diff_details and "minor" in diff_details and diff_details["minor"] > 0:
                        warning_level = "MINOR"
                    else:
                        warning_level = "MICRO"

            # Check failure threshold
            if config.vdiff_fail_matrix:
                fail_within_limits, diff_details = check_version_diff(
                    installed, latest, config.vdiff_fail_matrix)
                is_compatible = fail_within_limits

        # Create package info
        package_info = PackageInfo(
            name=package_name,
            installed_version=str(installed),
            latest_version=str(latest),
            required_version=str(config.get_required_version()) if config.requirement else None
        )

        # Create and return result
        return Result(
            is_compatible=is_compatible,
            needs_update=needs_update,
            package_info=package_info,
            warning_level=warning_level,
            diff_details=diff_details
        )

    except Exception as e:
        # Handle errors based on configuration
        if config.fail_on_error:
            raise
        return Result(
            is_compatible=False,
            needs_update=False,
            package_info=PackageInfo(name=package_name),
            error=str(e)
        )
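The VERSION_DIFF warning ladder in check_version can be isolated as a pure function (an illustrative extraction; diff_details maps a version component name to its gap size):

```python
from typing import Dict, Optional

# Warning-level ladder from check_version: MAJOR wins over MINOR, and
# anything else (including no detail at all) falls through to MICRO,
# matching the source above.
def warning_level(diff_details: Optional[Dict[str, int]]) -> str:
    if diff_details and diff_details.get("major", 0) > 0:
        return "MAJOR"
    if diff_details and diff_details.get("minor", 0) > 0:
        return "MINOR"
    return "MICRO"
```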
Result dataclass

Result of a version check operation.

Source code in src/tnh_scholar/utils/version_check/models.py
@dataclass
class Result:
    """Result of a version check operation."""

    is_compatible: bool
    needs_update: bool
    package_info: PackageInfo
    error: Optional[str] = None
    warning_level: Optional[str] = None
    diff_details: Optional[Dict[str, int]] = None

    def get_upgrade_command(self) -> str:
        """Return pip command to upgrade package."""
        if not self.package_info or not self.package_info.name:
            return ""

        if self.package_info.latest_version:
            return f"pip install --upgrade {self.package_info.name}=={self.package_info.latest_version}"
        else:
            return f"pip install --upgrade {self.package_info.name}"
diff_details = None class-attribute instance-attribute
error = None class-attribute instance-attribute
is_compatible instance-attribute
needs_update instance-attribute
package_info instance-attribute
warning_level = None class-attribute instance-attribute
__init__(is_compatible, needs_update, package_info, error=None, warning_level=None, diff_details=None)
get_upgrade_command()

Return pip command to upgrade package.

Source code in src/tnh_scholar/utils/version_check/models.py
27
28
29
30
31
32
33
34
35
def get_upgrade_command(self) -> str:
    """Return pip command to upgrade package."""
    if not self.package_info or not self.package_info.name:
        return ""

    if self.package_info.latest_version:
        return f"pip install --upgrade {self.package_info.name}=={self.package_info.latest_version}"
    else:
        return f"pip install --upgrade {self.package_info.name}"
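A standalone sketch of the two pip commands get_upgrade_command produces (Pkg is a minimal stand-in for PackageInfo):

```python
from dataclasses import dataclass
from typing import Optional

# Minimal stand-in for PackageInfo, carrying only the fields used here.
@dataclass
class Pkg:
    name: str
    latest_version: Optional[str] = None

# Mirrors Result.get_upgrade_command: pin to the latest version when it is
# known, otherwise fall back to a plain upgrade.
def upgrade_command(pkg: Pkg) -> str:
    if not pkg.name:
        return ""
    if pkg.latest_version:
        return f"pip install --upgrade {pkg.name}=={pkg.latest_version}"
    return f"pip install --upgrade {pkg.name}"
```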
VersionCheckerConfig

Configuration for version checking behavior.

Source code in src/tnh_scholar/utils/version_check/config.py
class VersionCheckerConfig:
    """Configuration for version checking behavior."""

    def __init__(self,
                 strategy: VersionStrategy = VersionStrategy.MINIMUM,
                 requirement: str = "",
                 fail_on_error: bool = False,
                 cache_duration: int = 3600,  # 1 hour
                 network_timeout: int = 5,    # seconds
                 vdiff_warn_matrix: Optional[str] = None,
                 vdiff_fail_matrix: Optional[str] = None):
        """Initialize version checker configuration."""
        self.strategy = strategy
        self.requirement = requirement
        self.fail_on_error = fail_on_error
        self.cache_duration = cache_duration
        self.network_timeout = network_timeout
        self.vdiff_warn_matrix = vdiff_warn_matrix
        self.vdiff_fail_matrix = vdiff_fail_matrix

    def get_required_version(self) -> Optional[Version]:
        """Get required version as a Version object."""
        return Version(self.requirement) if self.requirement else None
cache_duration = cache_duration instance-attribute
fail_on_error = fail_on_error instance-attribute
network_timeout = network_timeout instance-attribute
requirement = requirement instance-attribute
strategy = strategy instance-attribute
vdiff_fail_matrix = vdiff_fail_matrix instance-attribute
vdiff_warn_matrix = vdiff_warn_matrix instance-attribute
__init__(strategy=VersionStrategy.MINIMUM, requirement='', fail_on_error=False, cache_duration=3600, network_timeout=5, vdiff_warn_matrix=None, vdiff_fail_matrix=None)

Initialize version checker configuration.

Source code in src/tnh_scholar/utils/version_check/config.py
def __init__(self,
             strategy: VersionStrategy = VersionStrategy.MINIMUM,
             requirement: str = "",
             fail_on_error: bool = False,
             cache_duration: int = 3600,  # 1 hour
             network_timeout: int = 5,    # seconds
             vdiff_warn_matrix: Optional[str] = None,
             vdiff_fail_matrix: Optional[str] = None):
    """Initialize version checker configuration."""
    self.strategy = strategy
    self.requirement = requirement
    self.fail_on_error = fail_on_error
    self.cache_duration = cache_duration
    self.network_timeout = network_timeout
    self.vdiff_warn_matrix = vdiff_warn_matrix
    self.vdiff_fail_matrix = vdiff_fail_matrix
get_required_version()

Get required version as a Version object.

Source code in src/tnh_scholar/utils/version_check/config.py
def get_required_version(self) -> Optional[Version]:
    """Get required version as a Version object."""
    return Version(self.requirement) if self.requirement else None
VersionStrategy

Bases: Enum

Enumeration of version checking strategies.

Source code in src/tnh_scholar/utils/version_check/config.py
class VersionStrategy(Enum):
    """Enumeration of version checking strategies."""
    MINIMUM = "minimum"    # Package version must be >= requirement
    EXACT = "exact"        # Package version must be == requirement
    LATEST = "latest"      # Package version should be the latest available
    RANGE = "range"        # Package version must be within a specified range
    VERSION_DIFF = "vdiff" # Check version difference against thresholds
EXACT = 'exact' class-attribute instance-attribute
LATEST = 'latest' class-attribute instance-attribute
MINIMUM = 'minimum' class-attribute instance-attribute
RANGE = 'range' class-attribute instance-attribute
VERSION_DIFF = 'vdiff' class-attribute instance-attribute
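The MINIMUM and EXACT strategies reduce to ordinary version comparisons, sketched here with plain tuples instead of packaging.Version (illustrative only):

```python
# Parse "1.2.3" into (1, 2, 3) so tuple comparison orders versions
# componentwise, from most to least significant.
def parse(v: str) -> tuple:
    return tuple(int(p) for p in v.split("."))

# MINIMUM: installed version must be >= the requirement.
def meets_minimum(installed: str, required: str) -> bool:
    return parse(installed) >= parse(required)

# EXACT: installed version must equal the requirement.
def meets_exact(installed: str, required: str) -> bool:
    return parse(installed) == parse(required)
```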
cache

Simple caching mechanism for version information.

VersionCache

Simple time-based cache for version information.

Source code in src/tnh_scholar/utils/version_check/cache.py
class VersionCache:
    """Simple time-based cache for version information."""

    def __init__(self, cache_duration: int = 3600):
        """Initialize cache with specified expiration time in seconds."""
        self.cache: Dict[str, Version] = {}
        self.timestamps: Dict[str, float] = {}
        self.cache_duration = cache_duration

    def get(self, key: str) -> Optional[Version]:
        """Get cached version if still valid."""
        return self.cache.get(key) if self.is_valid(key) else None

    def set(self, key: str, value: Version) -> None:
        """Cache version with current timestamp."""
        self.cache[key] = value
        self.timestamps[key] = time.time()

    def is_valid(self, key: str) -> bool:
        """Check if cached value is still valid."""
        if key not in self.timestamps:
            return False
        age = time.time() - self.timestamps[key]
        return age < self.cache_duration
cache = {} instance-attribute
cache_duration = cache_duration instance-attribute
timestamps = {} instance-attribute
__init__(cache_duration=3600)

Initialize cache with specified expiration time in seconds.

Source code in src/tnh_scholar/utils/version_check/cache.py
def __init__(self, cache_duration: int = 3600):
    """Initialize cache with specified expiration time in seconds."""
    self.cache: Dict[str, Version] = {}
    self.timestamps: Dict[str, float] = {}
    self.cache_duration = cache_duration
get(key)

Get cached version if still valid.

Source code in src/tnh_scholar/utils/version_check/cache.py
def get(self, key: str) -> Optional[Version]:
    """Get cached version if still valid."""
    return self.cache.get(key) if self.is_valid(key) else None
is_valid(key)

Check if cached value is still valid.

Source code in src/tnh_scholar/utils/version_check/cache.py
def is_valid(self, key: str) -> bool:
    """Check if cached value is still valid."""
    if key not in self.timestamps:
        return False
    age = time.time() - self.timestamps[key]
    return age < self.cache_duration
set(key, value)

Cache version with current timestamp.

Source code in src/tnh_scholar/utils/version_check/cache.py
def set(self, key: str, value: Version) -> None:
    """Cache version with current timestamp."""
    self.cache[key] = value
    self.timestamps[key] = time.time()
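A usage sketch of the time-based caching pattern, with a minimal inline copy of the class so the block runs standalone (TinyCache is not part of the package):

```python
import time

# Minimal copy of the VersionCache pattern: values plus insertion timestamps,
# with reads that ignore entries older than cache_duration seconds.
class TinyCache:
    def __init__(self, cache_duration=3600):
        self.cache = {}
        self.timestamps = {}
        self.cache_duration = cache_duration

    def set(self, key, value):
        self.cache[key] = value
        self.timestamps[key] = time.time()

    def is_valid(self, key):
        if key not in self.timestamps:
            return False
        return time.time() - self.timestamps[key] < self.cache_duration

    def get(self, key):
        return self.cache.get(key) if self.is_valid(key) else None

cache = TinyCache(cache_duration=3600)
cache.set("requests", "2.32.3")
hit = cache.get("requests")            # fresh entry -> value
cache.timestamps["requests"] -= 7200   # backdate to simulate expiry
miss = cache.get("requests")           # stale entry -> None
```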
checker

Main version checker implementation.

PackageVersionChecker and its methods are documented above under version_check.
cli

Command-line interface for version checking (stub for future implementation).

main()

Command-line interface for version checking.

Source code in src/tnh_scholar/utils/version_check/cli.py
def main():
    """Command-line interface for version checking."""
    raise NotImplementedError("CLI functionality is not yet implemented")
config

Configuration classes for version checking.

VersionCheckerConfig

Configuration for version checking behavior.

Source code in src/tnh_scholar/utils/version_check/config.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class VersionCheckerConfig:
    """Configuration for version checking behavior."""

    def __init__(self,
                 strategy: VersionStrategy = VersionStrategy.MINIMUM,
                 requirement: str = "",
                 fail_on_error: bool = False,
                 cache_duration: int = 3600,  # 1 hour
                 network_timeout: int = 5,    # seconds
                 vdiff_warn_matrix: Optional[str] = None,
                 vdiff_fail_matrix: Optional[str] = None):
        """Initialize version checker configuration."""
        self.strategy = strategy
        self.requirement = requirement
        self.fail_on_error = fail_on_error
        self.cache_duration = cache_duration
        self.network_timeout = network_timeout
        self.vdiff_warn_matrix = vdiff_warn_matrix
        self.vdiff_fail_matrix = vdiff_fail_matrix

    def get_required_version(self) -> Optional[Version]:
        """Get required version as a Version object."""
        return Version(self.requirement) if self.requirement else None
cache_duration = cache_duration instance-attribute
fail_on_error = fail_on_error instance-attribute
network_timeout = network_timeout instance-attribute
requirement = requirement instance-attribute
strategy = strategy instance-attribute
vdiff_fail_matrix = vdiff_fail_matrix instance-attribute
vdiff_warn_matrix = vdiff_warn_matrix instance-attribute
__init__(strategy=VersionStrategy.MINIMUM, requirement='', fail_on_error=False, cache_duration=3600, network_timeout=5, vdiff_warn_matrix=None, vdiff_fail_matrix=None)

Initialize version checker configuration.

Source code in src/tnh_scholar/utils/version_check/config.py
def __init__(self,
             strategy: VersionStrategy = VersionStrategy.MINIMUM,
             requirement: str = "",
             fail_on_error: bool = False,
             cache_duration: int = 3600,  # 1 hour
             network_timeout: int = 5,    # seconds
             vdiff_warn_matrix: Optional[str] = None,
             vdiff_fail_matrix: Optional[str] = None):
    """Initialize version checker configuration."""
    self.strategy = strategy
    self.requirement = requirement
    self.fail_on_error = fail_on_error
    self.cache_duration = cache_duration
    self.network_timeout = network_timeout
    self.vdiff_warn_matrix = vdiff_warn_matrix
    self.vdiff_fail_matrix = vdiff_fail_matrix
get_required_version()

Get required version as a Version object.

Source code in src/tnh_scholar/utils/version_check/config.py
def get_required_version(self) -> Optional[Version]:
    """Get required version as a Version object."""
    return Version(self.requirement) if self.requirement else None
VersionStrategy

Bases: Enum

Enumeration of version checking strategies.

Source code in src/tnh_scholar/utils/version_check/config.py
class VersionStrategy(Enum):
    """Enumeration of version checking strategies."""
    MINIMUM = "minimum"    # Package version must be >= requirement
    EXACT = "exact"        # Package version must be == requirement
    LATEST = "latest"      # Package version should be the latest available
    RANGE = "range"        # Package version must be within a specified range
    VERSION_DIFF = "vdiff" # Check version difference against thresholds
EXACT = 'exact' class-attribute instance-attribute
LATEST = 'latest' class-attribute instance-attribute
MINIMUM = 'minimum' class-attribute instance-attribute
RANGE = 'range' class-attribute instance-attribute
VERSION_DIFF = 'vdiff' class-attribute instance-attribute
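In practice the enum values drive a dispatch to the matching check function in `strategies`. A self-contained sketch of that dispatch (the enum and the two simplest checks are restated inline, with `(major, minor, micro)` tuples standing in for `packaging.version.Version`):

```python
from enum import Enum
from typing import Callable, Dict, Optional, Tuple

class VersionStrategy(Enum):
    MINIMUM = "minimum"
    EXACT = "exact"

# Versions as plain tuples here; the real code uses packaging.version.Version,
# which supports the same >= and == comparisons.
V = Tuple[int, int, int]

def check_minimum(installed: V, required: Optional[V]) -> bool:
    return True if required is None else installed >= required

def check_exact(installed: V, required: Optional[V]) -> bool:
    return True if required is None else installed == required

DISPATCH: Dict[VersionStrategy, Callable[[V, Optional[V]], bool]] = {
    VersionStrategy.MINIMUM: check_minimum,
    VersionStrategy.EXACT: check_exact,
}

result = DISPATCH[VersionStrategy.MINIMUM]((1, 4, 0), (1, 2, 0))
```

With no requirement set, every strategy passes; with one set, `MINIMUM` accepts anything at or above it while `EXACT` accepts only that version.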
models

Data models for version checking results.

PackageInfo dataclass

Information about a package and its versions.

Source code in src/tnh_scholar/utils/version_check/models.py
@dataclass
class PackageInfo:
    """Information about a package and its versions."""

    name: str
    installed_version: Optional[str] = None
    latest_version: Optional[str] = None
    required_version: Optional[str] = None
installed_version = None class-attribute instance-attribute
latest_version = None class-attribute instance-attribute
name instance-attribute
required_version = None class-attribute instance-attribute
__init__(name, installed_version=None, latest_version=None, required_version=None)
Result dataclass

Result of a version check operation.

Source code in src/tnh_scholar/utils/version_check/models.py
@dataclass
class Result:
    """Result of a version check operation."""

    is_compatible: bool
    needs_update: bool
    package_info: PackageInfo
    error: Optional[str] = None
    warning_level: Optional[str] = None
    diff_details: Optional[Dict[str, int]] = None

    def get_upgrade_command(self) -> str:
        """Return pip command to upgrade package."""
        if not self.package_info or not self.package_info.name:
            return ""

        if self.package_info.latest_version:
            return f"pip install --upgrade {self.package_info.name}=={self.package_info.latest_version}"
        else:
            return f"pip install --upgrade {self.package_info.name}"
diff_details = None class-attribute instance-attribute
error = None class-attribute instance-attribute
is_compatible instance-attribute
needs_update instance-attribute
package_info instance-attribute
warning_level = None class-attribute instance-attribute
__init__(is_compatible, needs_update, package_info, error=None, warning_level=None, diff_details=None)
get_upgrade_command()

Return pip command to upgrade package.

Source code in src/tnh_scholar/utils/version_check/models.py
def get_upgrade_command(self) -> str:
    """Return pip command to upgrade package."""
    if not self.package_info or not self.package_info.name:
        return ""

    if self.package_info.latest_version:
        return f"pip install --upgrade {self.package_info.name}=={self.package_info.latest_version}"
    else:
        return f"pip install --upgrade {self.package_info.name}"
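A runnable illustration of `get_upgrade_command`. The two dataclasses are restated inline (trimmed to the fields the method touches) so the snippet is self-contained; the package name is illustrative:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class PackageInfo:
    name: str
    installed_version: Optional[str] = None
    latest_version: Optional[str] = None

@dataclass
class Result:
    is_compatible: bool
    needs_update: bool
    package_info: PackageInfo

    def get_upgrade_command(self) -> str:
        # Pin to the latest known version when available; otherwise upgrade unpinned.
        if not self.package_info or not self.package_info.name:
            return ""
        if self.package_info.latest_version:
            return f"pip install --upgrade {self.package_info.name}=={self.package_info.latest_version}"
        return f"pip install --upgrade {self.package_info.name}"

result = Result(True, True, PackageInfo("tnh-scholar", "0.1.0", "0.2.0"))
cmd = result.get_upgrade_command()  # "pip install --upgrade tnh-scholar==0.2.0"
```

When `latest_version` is unknown the command falls back to an unpinned `pip install --upgrade <name>`.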
providers

Version provider implementations for retrieving package versions.

StandardVersionProvider

Bases: VersionProvider

Standard implementation of version provider using importlib and PyPI.

Source code in src/tnh_scholar/utils/version_check/providers.py
class StandardVersionProvider(VersionProvider):
    """Standard implementation of version provider using importlib and PyPI."""

    def __init__(self, cache: Optional[VersionCache] = None, timeout: int = 5):
        self.cache = cache or VersionCache()
        self.timeout = timeout
        self.pypi_url_template = "https://pypi.org/pypi/{package}/json"

    def get_installed_version(self, package_name: str) -> Version:
        """Get installed package version."""
        try:
            if version_str := str(importlib.metadata.version(package_name)):
                return Version(version_str)
            else:
                raise InvalidVersion(f"{package_name} version string is empty")
        except importlib.metadata.PackageNotFoundError as e:
            raise ImportError(f"{package_name} is not installed") from e
        except InvalidVersion as e:
            raise InvalidVersion(f"Invalid version for {package_name}: {e}") from e

    def get_latest_version(self, package_name: str) -> Version:
        """Get latest available package version from PyPI."""
        # Check cache first
        if cached_version := self.cache.get(f"{package_name}_latest"):
            return cached_version

        # Fetch from PyPI
        url = self.pypi_url_template.format(package=package_name)
        try:
            return self._send_url_request(url, package_name)
        except requests.RequestException as e:
            raise requests.RequestException(
                f"Failed to fetch {package_name} version from PyPI: {e}"
            ) from e

    def _send_url_request(self, url, package_name):
        response = requests.get(url, timeout=self.timeout)
        response.raise_for_status()
        version_str = response.json()["info"]["version"]
        version = Version(version_str)

        # Cache the result
        self.cache.set(f"{package_name}_latest", version)

        return version
cache = cache or VersionCache() instance-attribute
pypi_url_template = 'https://pypi.org/pypi/{package}/json' instance-attribute
timeout = timeout instance-attribute
__init__(cache=None, timeout=5)
Source code in src/tnh_scholar/utils/version_check/providers.py
def __init__(self, cache: Optional[VersionCache] = None, timeout: int = 5):
    self.cache = cache or VersionCache()
    self.timeout = timeout
    self.pypi_url_template = "https://pypi.org/pypi/{package}/json"
get_installed_version(package_name)

Get installed package version.

Source code in src/tnh_scholar/utils/version_check/providers.py
def get_installed_version(self, package_name: str) -> Version:
    """Get installed package version."""
    try:
        if version_str := str(importlib.metadata.version(package_name)):
            return Version(version_str)
        else:
            raise InvalidVersion(f"{package_name} version string is empty")
    except importlib.metadata.PackageNotFoundError as e:
        raise ImportError(f"{package_name} is not installed") from e
    except InvalidVersion as e:
        raise InvalidVersion(f"Invalid version for {package_name}: {e}") from e
get_latest_version(package_name)

Get latest available package version from PyPI.

Source code in src/tnh_scholar/utils/version_check/providers.py
def get_latest_version(self, package_name: str) -> Version:
    """Get latest available package version from PyPI."""
    # Check cache first
    if cached_version := self.cache.get(f"{package_name}_latest"):
        return cached_version

    # Fetch from PyPI
    url = self.pypi_url_template.format(package=package_name)
    try:
        return self._send_url_request(url, package_name)
    except requests.RequestException as e:
        raise requests.RequestException(
            f"Failed to fetch {package_name} version from PyPI: {e}"
        ) from e
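The cache-first flow of `get_latest_version` reduces to a small pattern: look up `"{package}_latest"`, and only hit the network on a miss. A sketch with a plain dict standing in for `VersionCache` and a stub fetcher standing in for the PyPI request (names here are illustrative, not part of the package API):

```python
from typing import Callable, Dict, List

def cached_latest(package: str,
                  cache: Dict[str, str],
                  fetch: Callable[[str], str]) -> str:
    """Return the cached version if present; otherwise fetch and cache it."""
    key = f"{package}_latest"
    if cached := cache.get(key):
        return cached
    version = fetch(package)  # the network call in the real provider
    cache[key] = version
    return version

calls: List[str] = []
def fake_fetch(name: str) -> str:
    calls.append(name)
    return "2.0.0"

cache: Dict[str, str] = {}
first = cached_latest("requests", cache, fake_fetch)   # misses, fetches
second = cached_latest("requests", cache, fake_fetch)  # served from cache
```

The second call never reaches the fetcher, which is what keeps repeated checks within `cache_duration` free of network traffic.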
VersionProvider

Bases: ABC

Interface for retrieving package version information.

Source code in src/tnh_scholar/utils/version_check/providers.py
class VersionProvider(ABC):
    """Interface for retrieving package version information."""

    @abstractmethod
    def get_installed_version(self, package_name: str) -> Version:
        """Get installed package version."""
        pass

    @abstractmethod
    def get_latest_version(self, package_name: str) -> Version:
        """Get latest available package version."""
        pass
get_installed_version(package_name) abstractmethod

Get installed package version.

Source code in src/tnh_scholar/utils/version_check/providers.py
@abstractmethod
def get_installed_version(self, package_name: str) -> Version:
    """Get installed package version."""
    pass
get_latest_version(package_name) abstractmethod

Get latest available package version.

Source code in src/tnh_scholar/utils/version_check/providers.py
@abstractmethod
def get_latest_version(self, package_name: str) -> Version:
    """Get latest available package version."""
    pass
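Because `VersionProvider` is an ABC, alternate version sources can be plugged in, for example a fixed provider as a test double. A sketch under simplifying assumptions (the interface is restated inline, and versions are plain strings here, whereas the real interface returns `packaging.version.Version`):

```python
from abc import ABC, abstractmethod

class VersionProvider(ABC):
    """Interface restated inline; mirrors the ABC documented above."""

    @abstractmethod
    def get_installed_version(self, package_name: str) -> str: ...

    @abstractmethod
    def get_latest_version(self, package_name: str) -> str: ...

class FixedVersionProvider(VersionProvider):
    """Hypothetical test double: canned versions, no importlib or PyPI access."""

    def __init__(self, installed: str, latest: str):
        self.installed = installed
        self.latest = latest

    def get_installed_version(self, package_name: str) -> str:
        return self.installed

    def get_latest_version(self, package_name: str) -> str:
        return self.latest

provider = FixedVersionProvider("1.0.0", "1.2.0")
```

Swapping providers this way lets the checking logic be exercised deterministically, without network access.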
strategies

Version comparison strategies for package version checking.

check_exact_version(installed, required)

Check if installed version exactly matches requirement.

Source code in src/tnh_scholar/utils/version_check/strategies.py
def check_exact_version(installed: Version, required: Optional[Version]) -> bool:
    """Check if installed version exactly matches requirement."""
    return True if required is None else installed == required
check_minimum_version(installed, required)

Check if installed version meets minimum requirement.

Source code in src/tnh_scholar/utils/version_check/strategies.py
def check_minimum_version(installed: Version, required: Optional[Version]) -> bool:
    """Check if installed version meets minimum requirement."""
    return True if required is None else installed >= required
check_version_diff(installed, reference, vdiff_matrix)

Check if version difference is within specified limits.

Source code in src/tnh_scholar/utils/version_check/strategies.py
def check_version_diff(
    installed: Version,
    reference: Version,
    vdiff_matrix: str
) -> Tuple[bool, Dict[str, int]]:
    """Check if version difference is within specified limits."""
    # Calculate actual differences
    major_diff = abs(reference.major - installed.major)
    minor_diff = abs(reference.minor - installed.minor) if reference.major == installed.major else 0
    micro_diff = abs(reference.micro - installed.micro) if (reference.major == installed.major and 
                                                          reference.minor == installed.minor) else 0

    diff_details = {
        "major": major_diff,
        "minor": minor_diff,
        "micro": micro_diff
    }

    # If no matrix provided, differences are acceptable
    if not vdiff_matrix:
        return True, diff_details

    # Parse matrix
    major_limit, minor_limit, micro_limit = parse_vdiff_matrix(vdiff_matrix)

    # Check limits (None means no limit)
    if major_limit is not None and major_diff > major_limit:
        return False, diff_details

    if minor_limit is not None and minor_diff > minor_limit:
        return False, diff_details

    if micro_limit is not None and micro_diff > micro_limit:
        return False, diff_details

    return True, diff_details
parse_vdiff_matrix(matrix_str)

Parse a version difference matrix string.

Source code in src/tnh_scholar/utils/version_check/strategies.py
def parse_vdiff_matrix(matrix_str: str) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    """Parse a version difference matrix string."""
    parts = matrix_str.split(".")
    if len(parts) != 3:
        raise ValueError(f"Invalid version difference matrix: {matrix_str}")

    limits = []
    for part in parts:
        if part == "*":
            limits.append(None)  # No limit
        else:
            try:
                limits.append(int(part))
            except ValueError as e:
                raise ValueError(f"Invalid version component: {part}") from e

    return tuple(limits)  # Tuple[Optional[int], Optional[int], Optional[int]]
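The matrix string is three dot-separated components, one per version part (major, minor, micro), where `*` means "no limit". The parser is restated inline here so the demonstration runs standalone:

```python
from typing import Optional, Tuple

def parse_vdiff_matrix(matrix_str: str) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    """Parse 'major.minor.micro' limits; '*' means no limit on that part."""
    parts = matrix_str.split(".")
    if len(parts) != 3:
        raise ValueError(f"Invalid version difference matrix: {matrix_str}")

    limits = []
    for part in parts:
        if part == "*":
            limits.append(None)  # no limit
        else:
            try:
                limits.append(int(part))
            except ValueError as e:
                raise ValueError(f"Invalid version component: {part}") from e

    return tuple(limits)

# "0.2.*": no major drift, at most 2 minor versions behind, any micro drift.
limits = parse_vdiff_matrix("0.2.*")
```

`check_version_diff` then compares each computed difference against the corresponding limit, skipping any component whose limit is `None`.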

webhook_server

WebhookServer

A generic webhook server that can receive callbacks from external services.

Source code in src/tnh_scholar/utils/webhook_server.py
class WebhookServer:
    """A generic webhook server that can receive callbacks from external services."""

    def __init__(self, port: int = 5050):
        """
        Initialize webhook server with configuration.

        Args:
            port: The port to run the Flask server on
        """
        self.port = port
        self.app = self._create_flask_app()
        self.webhook_received = Condition()
        self.webhook_data = None
        self.flask_running = Event()
        self.flask_server_thread = None
        self.tunnel_process = None

    def _create_flask_app(self) -> Flask:
        """Create and configure Flask app with webhook endpoint."""
        app = Flask(__name__)

        @app.route('/healthcheck', methods=['GET'])
        def healthcheck():
            """Simple endpoint to verify the server is running."""
            return jsonify({
                'status': 'ok',
                'timestamp': datetime.now().isoformat(),
                'webhook_received': self.webhook_data is not None
            })

        @app.route('/webhook', methods=['POST'])
        def handle_webhook():
            """Receive webhook data from external services."""
            # Get JSON data from the request
            data = request.json

            # Log webhook receipt
            print("\n" + "="*40)
            print(f"WEBHOOK RECEIVED at {datetime.now().strftime('%H:%M:%S')}")

            if data is not None:
                print(f"Webhook data status: {data.get('status', 'unknown')}")

                # Update the shared state with proper synchronization
                with self.webhook_received:
                    self.webhook_data = data
                    self.webhook_received.notify_all()
                    print("Notification sent to waiting threads")
            else:
                print("Webhook received with no JSON data")

            # Always return a success response to acknowledge receipt
            return jsonify({'status': 'received'}), 200

        @app.route('/shutdown', methods=['POST'])
        def shutdown():
            """Endpoint to gracefully shut down the Flask server."""
            func = request.environ.get('werkzeug.server.shutdown')
            if func is None:
                raise RuntimeError('Not running with the Werkzeug Server')
            func()
            return 'Server shutting down...'

        return app

    def start_server(self) -> None:
        """Start Flask server in a separate thread."""
        # Check if server is already running
        if self.flask_running.is_set() and \
            self.flask_server_thread and \
                self.flask_server_thread.is_alive():
            print(f"Flask server already running on port {self.port}")
            return

        # Reset state
        self.flask_running.clear()

        # Create thread function that sets event when server starts
        def run_server():
            print(f"Starting Flask server on port {self.port}...")
            self.flask_running.set()
            self.app.run(
                host="0.0.0.0", 
                port=self.port, 
                debug=False, 
                use_reloader=False
                )
            self.flask_running.clear()
            print("Flask server has stopped")

        # Start server in a daemon thread
        self.flask_server_thread = Thread(target=run_server, daemon=True)
        self.flask_server_thread.start()

        # Wait for server to start
        if not self.flask_running.wait(timeout=5):
            raise RuntimeError("Flask server failed to start within timeout period")

        print(f"Flask server started successfully on port {self.port}")

    def shutdown_server(self) -> None:
        """Gracefully shut down the Flask server."""
        if not self.flask_running.is_set():
            print("Flask server is not running")
            return

        try:
            print("Shutting down Flask server...")
            requests.post(f"http://localhost:{self.port}/shutdown")

            # Wait for server to stop
            if self.flask_server_thread:
                self.flask_server_thread.join(timeout=5)

            if self.flask_running.is_set():
                print("WARNING: Flask server did not shut down gracefully")
            else:
                print("Flask server shut down successfully")
        except Exception as e:
            print(f"Error shutting down Flask server: {e}")

    def create_tunnel(self) -> Optional[str]:
        """
        Create a public webhook URL using py-localtunnel.

        Returns:
            Optional[str]: The public webhook URL or None if tunnel creation failed
        """
        # First check if pylt is installed
        try:
            subprocess.run(["pylt", "--version"], check=True, capture_output=True)
        except (subprocess.SubprocessError, FileNotFoundError) as e:
            print("ERROR: pylt not found. Install with: pip install pylt")
            raise RuntimeError("Tunnel not started.") from e

        print(f"Creating public tunnel to port {self.port}...")

        # Start the localtunnel process
        self.tunnel_process = subprocess.Popen(
            ["pylt", "port", str(self.port)],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True
        )

        # Give it time to establish the tunnel
        time.sleep(3)

        # Check if process started successfully
        if self.tunnel_process.poll() is not None:
            if self.tunnel_process.stderr:
                stderr = self.tunnel_process.stderr.read()
            else:
                raise RuntimeError("ERROR: Tunnel process failed: no stderr")
            if self.tunnel_process.stdout:
                stdout = self.tunnel_process.stdout.read()
            else:
                raise RuntimeError("ERROR: Tunnel process failed: no stdout")
            raise RuntimeError(
                f"ERROR: Tunnel process failed:\nSTDOUT: {stdout}\nSTDERR: {stderr}"
                )

        # Regular expression to find the URL in the output
        url_pattern = re.compile(r'https?://[^\s\'"]+')

        # Read output with timeout
        start_time = time.time()
        tunnel_url = None

        while time.time() - start_time < 15:  # Wait up to 15 seconds
            if not self.tunnel_process.stdout:
                raise RuntimeError("ERROR: tunnel process has no stdout.")

            line = self.tunnel_process.stdout.readline()
            if not line:
                time.sleep(0.1)
                continue
            print(f"Tunnel output: {line.strip()}")

            # Check if the URL pattern is found
            if match := url_pattern.search(line):
                tunnel_url = match[0]
                break

        if not tunnel_url:
            print("ERROR: Could not find tunnel URL in output")
            if self.tunnel_process.poll() is None:
                self.tunnel_process.terminate()
            return None

        print(f"Public tunnel created: {tunnel_url}")
        webhook_url = f"{tunnel_url}/webhook"

        # Verify the tunnel works
        try:
            response = requests.get(f"{tunnel_url}/healthcheck", timeout=20)
            if response.status_code == 200:
                print("Tunnel verified: Flask server is accessible")
            else:
                print(f"Tunnel health check returned status {response.status_code}")
                raise RuntimeError("Could not verify Tunnel.") 
        except requests.RequestException:
            raise

        return webhook_url

    def close_tunnel(self) -> None:
        """Close the tunnel if it's running."""
        if self.tunnel_process and self.tunnel_process.poll() is None:
            print("Closing tunnel...")
            self.tunnel_process.terminate()
            self.tunnel_process.wait(timeout=5)
            print("Tunnel closed")

    def wait_for_webhook(self, timeout: int = 120) -> Optional[Dict]:
        """
        Wait for webhook data to be received.

        Args:
            timeout: Maximum time to wait in seconds

        Returns:
            Optional[Dict]: The webhook data or None if timed out
        """
        print(f"Waiting for webhook callback (timeout: {timeout}s)...")

        with self.webhook_received:
            # Wait for notification with timeout
            webhook_received = self.webhook_received.wait(timeout=timeout)

            if webhook_received and self.webhook_data is not None:
                print("Webhook received with data")
                return self.webhook_data

        print(f"Timed out waiting for webhook after {timeout} seconds")
        return None

    def cleanup(self) -> None:
        """Clean up all resources."""
        self.close_tunnel()
        self.shutdown_server()
app = self._create_flask_app() instance-attribute
flask_running = Event() instance-attribute
flask_server_thread = None instance-attribute
port = port instance-attribute
tunnel_process = None instance-attribute
webhook_data = None instance-attribute
webhook_received = Condition() instance-attribute
__init__(port=5050)

Initialize webhook server with configuration.

Parameters:

Name Type Description Default
port int

The port to run the Flask server on

5050
Source code in src/tnh_scholar/utils/webhook_server.py
def __init__(self, port: int = 5050):
    """
    Initialize webhook server with configuration.

    Args:
        port: The port to run the Flask server on
    """
    self.port = port
    self.app = self._create_flask_app()
    self.webhook_received = Condition()
    self.webhook_data = None
    self.flask_running = Event()
    self.flask_server_thread = None
    self.tunnel_process = None
cleanup()

Clean up all resources.

Source code in src/tnh_scholar/utils/webhook_server.py
def cleanup(self) -> None:
    """Clean up all resources."""
    self.close_tunnel()
    self.shutdown_server()
close_tunnel()

Close the tunnel if it's running.

Source code in src/tnh_scholar/utils/webhook_server.py
def close_tunnel(self) -> None:
    """Close the tunnel if it's running."""
    if self.tunnel_process and self.tunnel_process.poll() is None:
        print("Closing tunnel...")
        self.tunnel_process.terminate()
        self.tunnel_process.wait(timeout=5)
        print("Tunnel closed")
create_tunnel()

Create a public webhook URL using py-localtunnel.

Returns:

Type Description
Optional[str]

Optional[str]: The public webhook URL or None if tunnel creation failed

Source code in src/tnh_scholar/utils/webhook_server.py
def create_tunnel(self) -> Optional[str]:
    """
    Create a public webhook URL using py-localtunnel.

    Returns:
        Optional[str]: The public webhook URL or None if tunnel creation failed
    """
    # First check if pylt is installed
    try:
        subprocess.run(["pylt", "--version"], check=True, capture_output=True)
    except (subprocess.SubprocessError, FileNotFoundError) as e:
        print("ERROR: pylt not found. Install with: pip install pylt")
        raise RuntimeError("Tunnel not started.") from e

    print(f"Creating public tunnel to port {self.port}...")

    # Start the localtunnel process
    self.tunnel_process = subprocess.Popen(
        ["pylt", "port", str(self.port)],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True
    )

    # Give it time to establish the tunnel
    time.sleep(3)

    # Check if process started successfully
    if self.tunnel_process.poll() is not None:
        if self.tunnel_process.stderr:
            stderr = self.tunnel_process.stderr.read()
        else:
            raise RuntimeError("ERROR: Tunnel process failed: no stderr")
        if self.tunnel_process.stdout:
            stdout = self.tunnel_process.stdout.read()
        else:
            raise RuntimeError("ERROR: Tunnel process failed: no stdout")
        raise RuntimeError(
            f"ERROR: Tunnel process failed:\nSTDOUT: {stdout}\nSTDERR: {stderr}"
            )

    # Regular expression to find the URL in the output
    url_pattern = re.compile(r'https?://[^\s\'"]+')

    # Read output with timeout
    start_time = time.time()
    tunnel_url = None

    while time.time() - start_time < 15:  # Wait up to 15 seconds
        if not self.tunnel_process.stdout:
            raise RuntimeError("ERROR: tunnel process has no stdout.")

        line = self.tunnel_process.stdout.readline()
        if not line:
            time.sleep(0.1)
            continue
        print(f"Tunnel output: {line.strip()}")

        # Check if the URL pattern is found
        if match := url_pattern.search(line):
            tunnel_url = match[0]
            break

    if not tunnel_url:
        print("ERROR: Could not find tunnel URL in output")
        if self.tunnel_process.poll() is None:
            self.tunnel_process.terminate()
        return None

    print(f"Public tunnel created: {tunnel_url}")
    webhook_url = f"{tunnel_url}/webhook"

    # Verify the tunnel works
    try:
        response = requests.get(f"{tunnel_url}/healthcheck", timeout=20)
        if response.status_code == 200:
            print("Tunnel verified: Flask server is accessible")
        else:
            print(f"Tunnel health check returned status {response.status_code}")
            raise RuntimeError("Could not verify Tunnel.") 
    except requests.RequestException:
        raise

    return webhook_url
shutdown_server()

Gracefully shut down the Flask server.

Source code in src/tnh_scholar/utils/webhook_server.py
def shutdown_server(self) -> None:
    """Gracefully shut down the Flask server."""
    if not self.flask_running.is_set():
        print("Flask server is not running")
        return

    try:
        print("Shutting down Flask server...")
        requests.post(f"http://localhost:{self.port}/shutdown")

        # Wait for server to stop
        if self.flask_server_thread:
            self.flask_server_thread.join(timeout=5)

        if self.flask_running.is_set():
            print("WARNING: Flask server did not shut down gracefully")
        else:
            print("Flask server shut down successfully")
    except Exception as e:
        print(f"Error shutting down Flask server: {e}")
start_server()

Start Flask server in a separate thread.

Source code in src/tnh_scholar/utils/webhook_server.py
def start_server(self) -> None:
    """Start Flask server in a separate thread."""
    # Check if server is already running
    if self.flask_running.is_set() and \
        self.flask_server_thread and \
            self.flask_server_thread.is_alive():
        print(f"Flask server already running on port {self.port}")
        return

    # Reset state
    self.flask_running.clear()

    # Create thread function that sets event when server starts
    def run_server():
        print(f"Starting Flask server on port {self.port}...")
        self.flask_running.set()
        self.app.run(
            host="0.0.0.0", 
            port=self.port, 
            debug=False, 
            use_reloader=False
            )
        self.flask_running.clear()
        print("Flask server has stopped")

    # Start server in a daemon thread
    self.flask_server_thread = Thread(target=run_server, daemon=True)
    self.flask_server_thread.start()

    # Wait for server to start
    if not self.flask_running.wait(timeout=5):
        raise RuntimeError("Flask server failed to start within timeout period")

    print(f"Flask server started successfully on port {self.port}")
wait_for_webhook(timeout=120)

Wait for webhook data to be received.

Parameters:

  timeout (int): Maximum time to wait in seconds. Default: 120.

Returns:

  Optional[Dict]: The webhook data, or None if timed out.

Source code in src/tnh_scholar/utils/webhook_server.py
def wait_for_webhook(self, timeout: int = 120) -> Optional[Dict]:
    """
    Wait for webhook data to be received.

    Args:
        timeout: Maximum time to wait in seconds

    Returns:
        Optional[Dict]: The webhook data or None if timed out
    """
    print(f"Waiting for webhook callback (timeout: {timeout}s)...")

    with self.webhook_received:
        # Wait for notification with timeout
        webhook_received = self.webhook_received.wait(timeout=timeout)

        if webhook_received and self.webhook_data is not None:
            print("Webhook received with data")
            return self.webhook_data

    print(f"Timed out waiting for webhook after {timeout} seconds")
    return None
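The wait/notify pattern behind wait_for_webhook (a threading.Condition guarding shared webhook data) can be reduced to a standalone sketch; `WebhookWaiter` and its names are illustrative, not part of the package API:

```python
import threading

class WebhookWaiter:
    """Minimal sketch of the Condition-based wait used by wait_for_webhook."""

    def __init__(self):
        self.received = threading.Condition()
        self.data = None

    def deliver(self, payload):
        # Called from the request-handler thread when the webhook fires.
        with self.received:
            self.data = payload
            self.received.notify_all()

    def wait_for(self, timeout=1.0):
        # Check data first so a delivery that beat us to the lock is not lost.
        with self.received:
            if self.data is None:
                self.received.wait(timeout=timeout)
            return self.data

waiter = WebhookWaiter()
# Simulate a webhook arriving shortly after we start waiting.
threading.Timer(0.05, waiter.deliver, args=({"ok": True},)).start()
result = waiter.wait_for(timeout=2.0)
```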

video_processing

video_processing

video_processing.py

BASE_YDL_OPTIONS = {'quiet': False, 'no_warnings': True, 'extract_flat': True, 'socket_timeout': 30, 'retries': 3, 'ignoreerrors': True, 'logger': logger} module-attribute
DEFAULT_AUDIO_OPTIONS = BASE_YDL_OPTIONS | {'format': 'bestaudio/best', 'postprocessors': [{'key': 'FFmpegExtractAudio', 'preferredcodec': 'mp3', 'preferredquality': '192'}], 'noplaylist': True} module-attribute
DEFAULT_METADATA_FIELDS = ['id', 'title', 'description', 'duration', 'upload_date', 'uploader', 'channel_url', 'webpage_url', 'original_url', 'channel', 'language', 'categories', 'tags'] module-attribute
DEFAULT_METADATA_OPTIONS = BASE_YDL_OPTIONS | {'skip_download': True} module-attribute
DEFAULT_TRANSCRIPT_OPTIONS = BASE_YDL_OPTIONS | {'skip_download': True, 'writesubtitles': True, 'writeautomaticsub': True, 'subtitlesformat': 'ttml'} module-attribute
DEFAULT_VIDEO_OPTIONS = BASE_YDL_OPTIONS | {'format': 'bestvideo+bestaudio/best', 'merge_output_format': 'mp4', 'noplaylist': True} module-attribute
TEMP_FILENAME_FORMAT = 'temp_%(id)s' module-attribute
TEMP_FILENAME_STR = 'temp_{id}' module-attribute
logger = get_child_logger(__name__) module-attribute
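The DEFAULT_* option dicts above are built from BASE_YDL_OPTIONS with the dict union operator (PEP 584, Python 3.9+), and the downloader methods layer caller config and per-call keys on top the same way. The precedence (rightmost operand wins) can be illustrated with placeholder keys:

```python
# Rightmost operand wins on key conflicts, so per-call keys override the
# caller's config, which in turn overrides the task defaults.
BASE = {"quiet": False, "retries": 3}
DEFAULTS = BASE | {"skip_download": True}           # task defaults
user_config = {"quiet": True}                       # caller-supplied config
options = DEFAULTS | user_config | {"retries": 5}   # per-call overrides
```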
DLPDownloader

Bases: YTDownloader

yt-dlp based implementation of YouTube content retrieval.

Assures temporary file export is in the form <ID>.<ext>, where ID is the YouTube video id and ext is the appropriate extension.

Renames the export file to be based on title and ID by default, or moves the export file to the specified output file with the appropriate extension.

Source code in src/tnh_scholar/video_processing/video_processing.py
class DLPDownloader(YTDownloader):
    """
    yt-dlp based implementation of YouTube content retrieval.

    Assures temporary file export is in the form <ID>.<ext> 
    where ID is the YouTube video id, and ext is the appropriate
    extension.

    Renames the export file to be based on title and ID by
    default, or moves the export file to the specified output
    file with appropriate extension.
    """

    def __init__(self, config: Optional[dict] = None):
        self.config = config or BASE_YDL_OPTIONS

    def get_metadata(
        self,
        url: str,
    ) -> Metadata:
        """
        Get metadata for a YouTube video. 
        """
        options = DEFAULT_METADATA_OPTIONS | self.config
        with yt_dlp.YoutubeDL(options) as ydl:
            if info := ydl.extract_info(url):
                return self._extract_metadata(info)
            logger.error(f"Unable to download metadata for {url}.")
            raise DownloadError("No info returned.")

    def get_transcript(
        self,
        url: str,
        lang: str = "en",
        output_path: Optional[Path] = None,
    ) -> VideoTranscript:
        """
        Downloads video transcript in TTML format.

        Args:
            url: YouTube video URL
            lang: Language code for transcript (default: "en")
            output_path: Optional output directory (uses current dir if None)

        Returns:
            VideoTranscript containing TTML file path and metadata

        Raises:
            TranscriptError: If no transcript found for specified language
        """
        temp_path = Path.cwd() / TEMP_FILENAME_FORMAT
        options = DEFAULT_TRANSCRIPT_OPTIONS | self.config | {
            "skip_download": True,
            "subtitleslangs": [lang],
            "outtmpl": str(temp_path),
        }

        with yt_dlp.YoutubeDL(options) as ydl:
            if info := ydl.extract_info(url):
                metadata = self._extract_metadata(info)
                filepath = Path(ydl.prepare_filename(info)).with_suffix(f".{lang}.ttml")
                filepath = self._convert_filename(filepath, metadata, output_path)
                return VideoTranscript(metadata=metadata, filepath=filepath)
            else:
                logger.error("Info not found.")
                raise TranscriptError(f"Transcript not downloaded for {url} in {lang}")

    def get_audio(
        self, 
        url: str, 
        start: Optional[str] = None,
        end: Optional[str] = None,
        output_path: Optional[Path] = None
    ) -> VideoAudio:
        """Download audio and get metadata for a YouTube video."""
        temp_path = Path.cwd() / TEMP_FILENAME_FORMAT
        options = DEFAULT_AUDIO_OPTIONS | self.config | {
            "outtmpl": str(temp_path)
        }

        self._add_start_stop_times(options, start, end)

        with yt_dlp.YoutubeDL(options) as ydl:
            if info := ydl.extract_info(url, download=True):
                metadata = self._extract_metadata(info)
                filepath = Path(ydl.prepare_filename(info)).with_suffix(".mp3")
                filepath = self._convert_filename(filepath, metadata, output_path)
                return VideoAudio(metadata=metadata, filepath=filepath)
            else:
                logger.error("Info not found.")
                raise DownloadError(f"Unable to download {url}.")

    def get_video(
        self,
        url: str,
        quality: Optional[str] = None,
        output_path: Optional[Path] = None
    ) -> VideoFile:
        """
        Download the full video with associated metadata.

        Args:
            url: YouTube video URL
            quality: yt-dlp format string (default: highest available)
            output_path: Optional output directory

        Returns:
            VideoFile containing video file path and metadata

        Raises:
            VideoDownloadError: If download fails
        """
        temp_path = Path.cwd() / TEMP_FILENAME_FORMAT
        video_options = DEFAULT_VIDEO_OPTIONS | self.config | {
            "outtmpl": str(temp_path)
        }
        if quality:
            video_options["format"] = quality

        with yt_dlp.YoutubeDL(video_options) as ydl:
            if info := ydl.extract_info(url, download=True):
                metadata = self._extract_metadata(info)
                ext = info.get("ext", "mp4")
                filepath = Path(ydl.prepare_filename(info)).with_suffix(f".{ext}")
                filepath = self._convert_filename(filepath, metadata, output_path)
                return VideoFile(metadata=metadata, filepath=filepath)
            else:
                logger.error("Info not found.")
                raise VideoDownloadError(f"Unable to download video for {url}.")

    #TODO this function is not affecting the start time of processing. 
    # find a fix or new implementation 
    # (pydub postprocessing after yt-dlp? keep yt-dlp minimal?)        
    def _add_start_stop_times(
        self, options: dict, start: Optional[str], end: Optional[str]) -> None:
        """
        Adds -ss and -to arguments for FFmpegExtractAudio via postprocessor_args dict.
        Modifies options in place.
        """
        if start or end:
            ppa_args = []
            if start:
                ppa_args.extend(["-ss", start])
                logger.debug(f"Added start time to postprocessor args: {start}")
            if end:
                ppa_args.extend(["-to", end])
                logger.debug(f"Added end time to postprocessor args: {end}")

            postprocessor_args = options.setdefault("postprocessor_args", {})

            postprocessor_args.setdefault("ExtractAudio", []).extend(ppa_args)

        logger.info(f"Updated options for postprocessor_args: "
                    f"{options.get('postprocessor_args')}")

    def _extract_metadata(self, info: dict) -> Metadata:
        """Extract standard metadata fields from yt-dlp info."""
        return Metadata.from_fields(info, DEFAULT_METADATA_FIELDS)

    def _show_info(self, info: Metadata) -> None:
        """Debug routine for displaying info."""
        for k in info:
            if data := str(info[k]):
                if len(data) < 200:
                    print(f"{k}: {data}")
                else:
                    print(f"{k}: {data[:200]} ...")

    def get_default_filename_stem(self, metadata: Metadata) -> str:
        """Generate the object download filename."""
        # Expect both id and title in YouTube metadata
        assert metadata["id"]
        assert metadata["title"] 
        video_id = str(metadata["id"])
        sanitized_title = sanitize_filename(str(metadata["title"]))
        return f"{sanitized_title}_{video_id}"

    def get_default_export_name(self, url) -> str:
        """Get default export filename for a URL."""
        metadata = self.get_metadata(url)
        return self.get_default_filename_stem(metadata)

    def _convert_filename(
        self, 
        temp_path: Path,
        metadata: Metadata, 
        output_path: Optional[Path]
    ) -> Path:
        """
        Move/rename file from temp_path to output_path if specified.
        If output_path is not provided, a sanitized title and video ID are 
        used to create the new filename. This function is required because yt-dlp 
        is not consistent in its output file naming (across subtitles, audio, metadata).
        In this interface implementation we use a temp_path and TEMP_FILENAME_FORMAT to
        specify the temporary output to be the video_id followed by the correct
        extension for all resources. This function then converts the temp_path
        to the appropriately named resource, using output_path if specified,
        or a default filename format ({sanitized_title}_{id}).
        """
        video_id = str(metadata["id"])
        if video_id not in str(temp_path):
            raise VideoProcessingError(f"Temporary path '{temp_path}' "
                                       f"does not contain video ID '{video_id}'.")
        if not temp_path.suffix:
            raise VideoProcessingError(f"Temporary path '{temp_path}' "
                                       "does not have a file extension.")

        if not output_path:
            new_filename = self.get_default_filename_stem(metadata)
            new_path = Path(str(temp_path).replace(
                TEMP_FILENAME_STR.format(id=video_id), new_filename
                )
            )
            logger.debug(f"Renaming downloaded YT resource to: {new_path}")
            return temp_path.rename(new_path)

        if not output_path.suffix:
            output_path = output_path.with_suffix(temp_path.suffix)
            logger.info(f"Added extension {temp_path.suffix} to output path")
        elif output_path.suffix != temp_path.suffix:
            output_path = output_path.with_suffix(temp_path.suffix)
            logger.warning(f"Replaced output extension with {temp_path.suffix}")
        return temp_path.rename(output_path)
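The temp-file rename performed by _convert_filename (substituting the temp_{id} stem with the default {sanitized_title}_{id} stem while keeping the extension) can be sketched as follows; the id and title values are hypothetical:

```python
import tempfile
from pathlib import Path

TEMP_FILENAME_STR = "temp_{id}"           # mirrors the module constant
video_id = "abc123"                       # hypothetical YouTube id
title_stem = f"Dharma_Talk_{video_id}"    # default {sanitized_title}_{id} stem

with tempfile.TemporaryDirectory() as d:
    temp_path = Path(d) / f"temp_{video_id}.mp3"
    temp_path.touch()
    # Swap the temp stem for the title-based stem, keeping the extension --
    # the same substitution _convert_filename performs when no output_path
    # is given.
    new_path = Path(str(temp_path).replace(
        TEMP_FILENAME_STR.format(id=video_id), title_stem))
    final = temp_path.rename(new_path)
```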
config = config or BASE_YDL_OPTIONS instance-attribute
__init__(config=None)
Source code in src/tnh_scholar/video_processing/video_processing.py
def __init__(self, config: Optional[dict] = None):
    self.config = config or BASE_YDL_OPTIONS
get_audio(url, start=None, end=None, output_path=None)

Download audio and get metadata for a YouTube video.

Source code in src/tnh_scholar/video_processing/video_processing.py
def get_audio(
    self, 
    url: str, 
    start: Optional[str] = None,
    end: Optional[str] = None,
    output_path: Optional[Path] = None
) -> VideoAudio:
    """Download audio and get metadata for a YouTube video."""
    temp_path = Path.cwd() / TEMP_FILENAME_FORMAT
    options = DEFAULT_AUDIO_OPTIONS | self.config | {
        "outtmpl": str(temp_path)
    }

    self._add_start_stop_times(options, start, end)

    with yt_dlp.YoutubeDL(options) as ydl:
        if info := ydl.extract_info(url, download=True):
            metadata = self._extract_metadata(info)
            filepath = Path(ydl.prepare_filename(info)).with_suffix(".mp3")
            filepath = self._convert_filename(filepath, metadata, output_path)
            return VideoAudio(metadata=metadata, filepath=filepath)
        else:
            logger.error("Info not found.")
            raise DownloadError(f"Unable to download {url}.")
get_default_export_name(url)

Get default export filename for a URL.

Source code in src/tnh_scholar/video_processing/video_processing.py
def get_default_export_name(self, url) -> str:
    """Get default export filename for a URL."""
    metadata = self.get_metadata(url)
    return self.get_default_filename_stem(metadata)
get_default_filename_stem(metadata)

Generate the object download filename.

Source code in src/tnh_scholar/video_processing/video_processing.py
def get_default_filename_stem(self, metadata: Metadata) -> str:
    """Generate the object download filename."""
    # Expect both id and title in YouTube metadata
    assert metadata["id"]
    assert metadata["title"] 
    video_id = str(metadata["id"])
    sanitized_title = sanitize_filename(str(metadata["title"]))
    return f"{sanitized_title}_{video_id}"
get_metadata(url)

Get metadata for a YouTube video.

Source code in src/tnh_scholar/video_processing/video_processing.py
def get_metadata(
    self,
    url: str,
) -> Metadata:
    """
    Get metadata for a YouTube video. 
    """
    options = DEFAULT_METADATA_OPTIONS | self.config
    with yt_dlp.YoutubeDL(options) as ydl:
        if info := ydl.extract_info(url):
            return self._extract_metadata(info)
        logger.error(f"Unable to download metadata for {url}.")
        raise DownloadError("No info returned.")
get_transcript(url, lang='en', output_path=None)

Downloads video transcript in TTML format.

Parameters:

  url (str): YouTube video URL. Required.
  lang (str): Language code for transcript. Default: 'en'.
  output_path (Optional[Path]): Output directory; uses current dir if None. Default: None.

Returns:

  VideoTranscript: The transcript resource containing TTML file path and metadata.

Raises:

  TranscriptError: If no transcript found for specified language.

Source code in src/tnh_scholar/video_processing/video_processing.py
def get_transcript(
    self,
    url: str,
    lang: str = "en",
    output_path: Optional[Path] = None,
) -> VideoTranscript:
    """
    Downloads video transcript in TTML format.

    Args:
        url: YouTube video URL
        lang: Language code for transcript (default: "en")
        output_path: Optional output directory (uses current dir if None)

    Returns:
        VideoTranscript containing TTML file path and metadata

    Raises:
        TranscriptError: If no transcript found for specified language
    """
    temp_path = Path.cwd() / TEMP_FILENAME_FORMAT
    options = DEFAULT_TRANSCRIPT_OPTIONS | self.config | {
        "skip_download": True,
        "subtitleslangs": [lang],
        "outtmpl": str(temp_path),
    }

    with yt_dlp.YoutubeDL(options) as ydl:
        if info := ydl.extract_info(url):
            metadata = self._extract_metadata(info)
            filepath = Path(ydl.prepare_filename(info)).with_suffix(f".{lang}.ttml")
            filepath = self._convert_filename(filepath, metadata, output_path)
            return VideoTranscript(metadata=metadata, filepath=filepath)
        else:
            logger.error("Info not found.")
            raise TranscriptError(f"Transcript not downloaded for {url} in {lang}")
get_video(url, quality=None, output_path=None)

Download the full video with associated metadata.

Parameters:

  url (str): YouTube video URL. Required.
  quality (Optional[str]): yt-dlp format string; default is highest available. Default: None.
  output_path (Optional[Path]): Optional output directory. Default: None.

Returns:

  VideoFile: VideoFile containing video file path and metadata.

Raises:

  VideoDownloadError: If download fails.

Source code in src/tnh_scholar/video_processing/video_processing.py
def get_video(
    self,
    url: str,
    quality: Optional[str] = None,
    output_path: Optional[Path] = None
) -> VideoFile:
    """
    Download the full video with associated metadata.

    Args:
        url: YouTube video URL
        quality: yt-dlp format string (default: highest available)
        output_path: Optional output directory

    Returns:
        VideoFile containing video file path and metadata

    Raises:
        VideoDownloadError: If download fails
    """
    temp_path = Path.cwd() / TEMP_FILENAME_FORMAT
    video_options = DEFAULT_VIDEO_OPTIONS | self.config | {
        "outtmpl": str(temp_path)
    }
    if quality:
        video_options["format"] = quality

    with yt_dlp.YoutubeDL(video_options) as ydl:
        if info := ydl.extract_info(url, download=True):
            metadata = self._extract_metadata(info)
            ext = info.get("ext", "mp4")
            filepath = Path(ydl.prepare_filename(info)).with_suffix(f".{ext}")
            filepath = self._convert_filename(filepath, metadata, output_path)
            return VideoFile(metadata=metadata, filepath=filepath)
        else:
            logger.error("Info not found.")
            raise VideoDownloadError(f"Unable to download video for {url}.")
DownloadError

Bases: VideoProcessingError

Raised for download-related errors.

Source code in src/tnh_scholar/video_processing/video_processing.py
class DownloadError(VideoProcessingError):
    """Raised for download-related errors."""
    pass
TranscriptError

Bases: VideoProcessingError

Raised for transcript-related errors.

Source code in src/tnh_scholar/video_processing/video_processing.py
class TranscriptError(VideoProcessingError):
    """Raised for transcript-related errors."""
    pass
VideoAudio dataclass

Bases: VideoResource

Source code in src/tnh_scholar/video_processing/video_processing.py
class VideoAudio(VideoResource): 
    pass 
VideoDownloadError

Bases: VideoProcessingError

Raised for video download-related errors.

Source code in src/tnh_scholar/video_processing/video_processing.py
class VideoDownloadError(VideoProcessingError):
    """Raised for video download-related errors."""
    pass
VideoFile dataclass

Bases: VideoResource

Represents a downloaded video file and its metadata.

Source code in src/tnh_scholar/video_processing/video_processing.py
class VideoFile(VideoResource):
    """Represents a downloaded video file and its metadata."""
    pass
VideoProcessingError

Bases: Exception

Base exception for video processing errors.

Source code in src/tnh_scholar/video_processing/video_processing.py
class VideoProcessingError(Exception):
    """Base exception for video processing errors."""
    pass
VideoResource dataclass

Base class for all video resources.

Source code in src/tnh_scholar/video_processing/video_processing.py
@dataclass 
class VideoResource:
    """Base class for all video resources."""
    metadata: Metadata
    filepath: Optional[Path] = None
filepath = None class-attribute instance-attribute
metadata instance-attribute
__init__(metadata, filepath=None)
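A quick illustration of how the resource dataclasses compose; Metadata is assumed dict-like here purely for the sketch:

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

Metadata = dict  # stand-in for the package's Metadata type

@dataclass
class VideoResource:
    """Base class for all video resources."""
    metadata: Metadata
    filepath: Optional[Path] = None

class VideoAudio(VideoResource):
    pass  # inherits the generated __init__

res = VideoAudio(metadata={"id": "abc123"}, filepath=Path("talk.mp3"))
```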
VideoTranscript dataclass

Bases: VideoResource

Source code in src/tnh_scholar/video_processing/video_processing.py
class VideoTranscript(VideoResource): 
    pass
YTDownloader

Abstract base class for YouTube content retrieval.

Source code in src/tnh_scholar/video_processing/video_processing.py
class YTDownloader:
    """Abstract base class for YouTube content retrieval."""

    def get_transcript(
        self, 
        url: str, 
        lang: str = "en", 
        output_path: Optional[Path] = None
    ) -> VideoTranscript:
        """Retrieve video transcript with associated metadata."""
        raise NotImplementedError

    def get_audio(
        self, 
        url: str, 
        start: str,
        end: str,
        output_path: Optional[Path]
    ) -> VideoAudio:
        """Extract audio with associated metadata."""
        raise NotImplementedError

    def get_metadata(
        self, 
        url: str, 
    ) -> Metadata:
        """Retrieve video metadata only."""
        raise NotImplementedError

    def get_video(
        self,
        url: str,
        quality: Optional[str] = None,
        output_path: Optional[Path] = None
    ) -> VideoFile:
        """
        Download the full video with associated metadata.

        Args:
            url: YouTube video URL
            quality: yt-dlp format string (default: highest available)
            output_path: Optional output directory

        Returns:
            VideoFile containing video file path and metadata

        Raises:
            VideoDownloadError: If download fails
        """
        raise NotImplementedError
get_audio(url, start, end, output_path)

Extract audio with associated metadata.

Source code in src/tnh_scholar/video_processing/video_processing.py
def get_audio(
    self, 
    url: str, 
    start: str,
    end: str,
    output_path: Optional[Path]
) -> VideoAudio:
    """Extract audio with associated metadata."""
    raise NotImplementedError
get_metadata(url)

Retrieve video metadata only.

Source code in src/tnh_scholar/video_processing/video_processing.py
def get_metadata(
    self, 
    url: str, 
) -> Metadata:
    """Retrieve video metadata only."""
    raise NotImplementedError
get_transcript(url, lang='en', output_path=None)

Retrieve video transcript with associated metadata.

Source code in src/tnh_scholar/video_processing/video_processing.py
def get_transcript(
    self, 
    url: str, 
    lang: str = "en", 
    output_path: Optional[Path] = None
) -> VideoTranscript:
    """Retrieve video transcript with associated metadata."""
    raise NotImplementedError
get_video(url, quality=None, output_path=None)

Download the full video with associated metadata.

Parameters:

  url (str): YouTube video URL. Required.
  quality (Optional[str]): yt-dlp format string; default is highest available. Default: None.
  output_path (Optional[Path]): Optional output directory. Default: None.

Returns:

  VideoFile: VideoFile containing video file path and metadata.

Raises:

  VideoDownloadError: If download fails.

Source code in src/tnh_scholar/video_processing/video_processing.py
def get_video(
    self,
    url: str,
    quality: Optional[str] = None,
    output_path: Optional[Path] = None
) -> VideoFile:
    """
    Download the full video with associated metadata.

    Args:
        url: YouTube video URL
        quality: yt-dlp format string (default: highest available)
        output_path: Optional output directory

    Returns:
        VideoFile containing video file path and metadata

    Raises:
        VideoDownloadError: If download fails
    """
    raise NotImplementedError
extract_text_from_ttml(ttml_path)

Extract plain text content from TTML file.

Parameters:

  ttml_path (Path): Path to TTML transcript file. Required.

Returns:

  str: Plain text content with one sentence per line.

Raises:

  ValueError: If file doesn't exist or has invalid content.

Source code in src/tnh_scholar/video_processing/video_processing.py
def extract_text_from_ttml(ttml_path: Path) -> str:
    """Extract plain text content from TTML file.

    Args:
        ttml_path: Path to TTML transcript file

    Returns:
        Plain text content with one sentence per line

    Raises:
        ValueError: If file doesn't exist or has invalid content
    """
    if not ttml_path.exists():
        raise ValueError(f"TTML file not found: {ttml_path}")

    ttml_str = ttml_path.read_text()
    if not ttml_str.strip():
        return ""

    namespaces = {
        "tt": "http://www.w3.org/ns/ttml",
        "tts": "http://www.w3.org/ns/ttml#styling",
    }

    try:
        root = ET.fromstring(ttml_str)
        text_lines = []
        for p in root.findall(".//tt:p", namespaces):
            if p.text is not None:
                text_lines.append(p.text.strip())
            else:
                text_lines.append("")
                logger.debug("Found empty paragraph in TTML, preserving as blank line")

        logger.info(f"Extracted {len(text_lines)} lines of text from TTML")
        return "\n".join(text_lines)

    except ParseError as e:
        logger.error(f"Failed to parse XML content: {e}")
        raise
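The namespace-aware extraction above can be exercised on an in-memory TTML fragment; the sample text below is illustrative:

```python
import xml.etree.ElementTree as ET

# A minimal TTML fragment shaped like the transcripts yt-dlp writes.
ttml = """<?xml version="1.0" encoding="utf-8"?>
<tt xmlns="http://www.w3.org/ns/ttml">
  <body><div>
    <p begin="00:00:01.000" end="00:00:03.000">Breathing in, I calm my body.</p>
    <p begin="00:00:03.000" end="00:00:05.000">Breathing out, I smile.</p>
  </div></body>
</tt>"""

namespaces = {"tt": "http://www.w3.org/ns/ttml"}
root = ET.fromstring(ttml)
# Paragraph elements carry the default TTML namespace, so the tt: prefix
# (mapped above) is required for findall to match them.
lines = [p.text.strip() for p in root.findall(".//tt:p", namespaces) if p.text]
text = "\n".join(lines)
```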
get_youtube_urls_from_csv(file_path)

Reads a CSV file containing YouTube URLs and titles, logs the titles, and returns a list of URLs.

Parameters:

  file_path (Path): Path to the CSV file containing YouTube URLs and titles. Required.

Returns:

  List[str]: List of YouTube URLs.

Raises:

  FileNotFoundError: If the file does not exist.
  ValueError: If the CSV file is improperly formatted.

Source code in src/tnh_scholar/video_processing/video_processing.py
def get_youtube_urls_from_csv(file_path: Path) -> List[str]:
    """
    Reads a CSV file containing YouTube URLs and titles, logs the titles,
    and returns a list of URLs.

    Args:
        file_path (Path): Path to the CSV file containing YouTube URLs and titles.

    Returns:
        List[str]: List of YouTube URLs.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the CSV file is improperly formatted.
    """
    if not file_path.exists():
        logger.error(f"File not found: {file_path}")
        raise FileNotFoundError(f"File not found: {file_path}")

    urls = []

    try:
        with file_path.open("r", encoding="utf-8") as f:
            reader = csv.DictReader(f)

            if (reader.fieldnames is None 
                or "url" not in reader.fieldnames 
                or "title" not in reader.fieldnames
            ):
                logger.error("CSV file must contain 'url' and 'title' columns.")
                raise ValueError("CSV file must contain 'url' and 'title' columns.")

            for row in reader:
                url = row["url"]
                title = row["title"]
                urls.append(url)
                logger.info(f"Found video title: {title}")
    except Exception as e:
        logger.exception(f"Error processing CSV file: {e}")
        raise

    return urls
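The csv.DictReader header validation used above can be exercised against an in-memory file; the URLs and titles below are placeholders:

```python
import csv
import io

# In-memory stand-in for the CSV layout get_youtube_urls_from_csv expects:
# a header row with "url" and "title" columns.
csv_text = (
    "url,title\n"
    "https://youtu.be/abc123,Morning Talk\n"
    "https://youtu.be/def456,Evening Talk\n"
)

reader = csv.DictReader(io.StringIO(csv_text))
if reader.fieldnames is None or not {"url", "title"} <= set(reader.fieldnames):
    raise ValueError("CSV file must contain 'url' and 'title' columns.")
urls = [row["url"] for row in reader]
```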

video_processing_old1

DEFAULT_TRANSCRIPT_DIR = Path.home() / '.yt_dlp_transcripts' module-attribute
DEFAULT_TRANSCRIPT_OPTIONS = {'skip_download': True, 'quiet': True, 'no_warnings': True, 'extract_flat': True, 'socket_timeout': 30, 'retries': 3, 'ignoreerrors': True, 'logger': logger} module-attribute
logger = get_child_logger(__name__) module-attribute
SubtitleTrack

Bases: TypedDict

Type definition for a subtitle track entry.

Source code in src/tnh_scholar/video_processing/video_processing_old1.py
class SubtitleTrack(TypedDict):
    """Type definition for a subtitle track entry."""

    url: str
    ext: str
    name: str
ext instance-attribute
name instance-attribute
url instance-attribute
TranscriptNotFoundError

Bases: Exception

Raised when no transcript is available for the requested language.

Source code in src/tnh_scholar/video_processing/video_processing_old1.py
class TranscriptNotFoundError(Exception):
    """Raised when no transcript is available for the requested language."""

    def __init__(
        self,
        video_url: str,
        language: str,
    ) -> None:
        """
        Initialize TranscriptNotFoundError.

        Args:
            video_url: URL of the video where transcript was not found
            language: Language code that was requested
        """
        self.video_url = video_url
        self.language = language

        message = (
            f"No transcript found for {self.video_url} in language {self.language}. "
        )
        super().__init__(message)
language = language instance-attribute
video_url = video_url instance-attribute
__init__(video_url, language)

Initialize TranscriptNotFoundError.

Parameters:

  video_url (str): URL of the video where transcript was not found. Required.
  language (str): Language code that was requested. Required.
Source code in src/tnh_scholar/video_processing/video_processing_old1.py
def __init__(
    self,
    video_url: str,
    language: str,
) -> None:
    """
    Initialize TranscriptNotFoundError.

    Args:
        video_url: URL of the video where transcript was not found
        language: Language code that was requested
    """
    self.video_url = video_url
    self.language = language

    message = (
        f"No transcript found for {self.video_url} in language {self.language}. "
    )
    super().__init__(message)
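Calling code typically catches this exception to fall back to another language. A minimal, self-contained sketch (the exception class is reproduced here, and `_fake_fetch` is a hypothetical stand-in for the real transcript call):

```python
class TranscriptNotFoundError(Exception):
    """Raised when no transcript is available for the requested language."""

    def __init__(self, video_url: str, language: str) -> None:
        self.video_url = video_url
        self.language = language
        super().__init__(
            f"No transcript found for {self.video_url} in language {self.language}. "
        )


def fetch_transcript_with_fallback(url: str, preferred: str, fallback: str = "en") -> str:
    """Illustrative only: try the preferred language, fall back on failure."""

    def _fake_fetch(url: str, lang: str) -> str:
        # Stand-in for get_transcript(); real code would hit the network.
        if lang != "en":
            raise TranscriptNotFoundError(video_url=url, language=lang)
        return "transcript text"

    try:
        return _fake_fetch(url, preferred)
    except TranscriptNotFoundError:
        # The exception carries video_url and language for logging.
        return _fake_fetch(url, fallback)
```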
VideoInfo

Bases: TypedDict

Type definition for relevant video info fields.

Source code in src/tnh_scholar/video_processing/video_processing_old1.py
class VideoInfo(TypedDict):
    """Type definition for relevant video info fields."""

    subtitles: Dict[str, List[SubtitleTrack]]
    automatic_captions: Dict[str, List[SubtitleTrack]]
automatic_captions instance-attribute
subtitles instance-attribute
download_audio_yt(url, output_dir, start_time=None, prompt_overwrite=True)

Downloads audio from a YouTube video using yt_dlp.YoutubeDL, with an optional start time.

Parameters:

Name Type Description Default
url str

URL of the YouTube video.

required
output_dir Path

Directory to save the downloaded audio file.

required
start_time str

Optional start time (e.g., '00:01:30' for 1 minute 30 seconds).

None
prompt_overwrite bool

Whether to prompt before overwriting an existing download (currently unused in the function body).

True

Returns:

Name Type Description
Path Path

Path to the downloaded audio file.

Source code in src/tnh_scholar/video_processing/video_processing_old1.py
def download_audio_yt(
    url: str, output_dir: Path, start_time: Optional[str] = None, prompt_overwrite=True
) -> Path:
    """
    Downloads audio from a YouTube video using yt_dlp.YoutubeDL, with an optional start time.

    Args:
        url (str): URL of the YouTube video.
        output_dir (Path): Directory to save the downloaded audio file.
        start_time (str): Optional start time (e.g., '00:01:30' for 1 minute 30 seconds).
        prompt_overwrite (bool): Whether to prompt before overwriting an existing download (currently unused).

    Returns:
        Path: Path to the downloaded audio file.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    ydl_opts = {
        "format": "bestaudio/best",
        "postprocessors": [
            {
                "key": "FFmpegExtractAudio",
                "preferredcodec": "mp3",
                "preferredquality": "192",
            }
        ],
        "postprocessor_args": [],
        "noplaylist": True,
        "outtmpl": str(output_dir / "%(title)s.%(ext)s"),
    }

    # Add start time to the FFmpeg postprocessor if provided
    if start_time:
        ydl_opts["postprocessor_args"].extend(["-ss", start_time])
        logger.info(f"Postprocessor start time set to: {start_time}")

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(url, download=True)  # Extract metadata and download
        filename = ydl.prepare_filename(info)
        return Path(filename).with_suffix(".mp3")
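The options dict above can be assembled separately from the download call, which makes the start-time handling easy to test. A minimal sketch of that assembly (`build_audio_options` is illustrative, not part of the package):

```python
from pathlib import Path
from typing import Optional


def build_audio_options(output_dir: Path, start_time: Optional[str] = None) -> dict:
    """Assemble the yt-dlp options used for MP3 audio extraction (illustrative)."""
    ydl_opts = {
        "format": "bestaudio/best",
        "postprocessors": [
            {
                "key": "FFmpegExtractAudio",
                "preferredcodec": "mp3",
                "preferredquality": "192",
            }
        ],
        "postprocessor_args": [],
        "noplaylist": True,
        "outtmpl": str(output_dir / "%(title)s.%(ext)s"),
    }
    if start_time:
        # FFmpeg's -ss flag seeks to the given timestamp before extracting.
        ydl_opts["postprocessor_args"].extend(["-ss", start_time])
    return ydl_opts
```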
get_transcript(url, lang='en', download_dir=DEFAULT_TRANSCRIPT_DIR, keep_transcript_file=False)

Downloads and extracts the transcript for a given YouTube video URL.

Retrieves the transcript file, extracts the text content, and returns the raw text.

Parameters:

Name Type Description Default
url str

The URL of the YouTube video.

required
lang str

The language code for the transcript (default: 'en').

'en'
download_dir Path

The directory to download the transcript to.

DEFAULT_TRANSCRIPT_DIR
keep_transcript_file bool

Whether to keep the downloaded transcript file (default: False).

False

Returns:

Type Description
str

The extracted transcript text.

Raises:

Type Description
TranscriptNotFoundError

If no transcript is available in the specified language.

DownloadError

If video info extraction or download fails.

ValueError

If the downloaded transcript file is invalid or empty.

ParseError

If XML parsing of the transcript fails.

Source code in src/tnh_scholar/video_processing/video_processing_old1.py
def get_transcript(
    url: str,
    lang: str = "en",
    download_dir: Path = DEFAULT_TRANSCRIPT_DIR,
    keep_transcript_file: bool = False,
) -> str:
    """Downloads and extracts the transcript for a given YouTube video URL.

    Retrieves the transcript file, extracts the text content, and returns the raw text.

    Args:
        url: The URL of the YouTube video.
        lang: The language code for the transcript (default: 'en').
        download_dir: The directory to download the transcript to.
        keep_transcript_file: Whether to keep the downloaded transcript file (default: False).

    Returns:
        The extracted transcript text.

    Raises:
        TranscriptNotFoundError: If no transcript is available in the specified language.
        yt_dlp.utils.DownloadError: If video info extraction or download fails.
        ValueError: If the downloaded transcript file is invalid or empty.
        ParseError: If XML parsing of the transcript fails.
    """

    transcript_file = _download_yt_ttml(download_dir, url=url, lang=lang)

    text = read_str_from_file(transcript_file)

    if not keep_transcript_file:
        try:
            os.remove(transcript_file)
            logger.debug(f"Removed temporary transcript file: {transcript_file}")
        except OSError as e:
            logger.warning(
                f"Failed to remove temporary transcript file {transcript_file}: {e}"
            )

    return _extract_ttml_text(text)
get_transcript_info(video_url, lang='en')

Retrieves the transcript URL for a video in the specified language.

Parameters:

Name Type Description Default
video_url str

The URL of the video

required
lang str

The desired language code

'en'

Returns:

Type Description
str

URL of the transcript

Raises:

Type Description
TranscriptNotFoundError

If no transcript is available in the specified language

DownloadError

If video info extraction fails

Source code in src/tnh_scholar/video_processing/video_processing_old1.py
def get_transcript_info(video_url: str, lang: str = "en") -> str:
    """
    Retrieves the transcript URL for a video in the specified language.

    Args:
        video_url: The URL of the video
        lang: The desired language code

    Returns:
        URL of the transcript

    Raises:
        TranscriptNotFoundError: If no transcript is available in the specified language
        yt_dlp.utils.DownloadError: If video info extraction fails
    """
    options = {
        "writesubtitles": True,
        "writeautomaticsub": True,
        "subtitleslangs": [lang],
        "skip_download": True,
        #    'verbose': True
    }

    with yt_dlp.YoutubeDL(options) as ydl:
        # This may raise yt_dlp.utils.DownloadError which we let propagate
        info: VideoInfo = ydl.extract_info(video_url, download=False)  # type: ignore

        subtitles = info.get("subtitles", {})
        auto_subtitles = info.get("automatic_captions", {})

        # Log available subtitle information
        logger.debug("Available subtitles:")
        logger.debug(f"Manual subtitles: {list(subtitles.keys())}")
        logger.debug(f"Auto captions: {list(auto_subtitles.keys())}")

        if lang in subtitles:
            return subtitles[lang][0]["url"]
        elif lang in auto_subtitles:
            return auto_subtitles[lang][0]["url"]

        raise TranscriptNotFoundError(video_url=video_url, language=lang)
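The manual-before-auto selection logic can be isolated as a pure function. `pick_transcript_url` below is an illustrative sketch (the real function raises `TranscriptNotFoundError` instead of returning `None`):

```python
from typing import Dict, List, Optional


def pick_transcript_url(
    subtitles: Dict[str, List[dict]],
    auto_captions: Dict[str, List[dict]],
    lang: str,
) -> Optional[str]:
    """Prefer a manual subtitle track; fall back to auto-generated captions."""
    if lang in subtitles:
        return subtitles[lang][0]["url"]
    if lang in auto_captions:
        return auto_captions[lang][0]["url"]
    return None  # the real function raises TranscriptNotFoundError here
```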
get_video_download_path_yt(output_dir, url)

Computes the expected audio download path for a YouTube video using yt-dlp metadata.

Parameters:

Name Type Description Default
output_dir Path

Directory where the audio file would be saved.

required
url str

The YouTube URL.

required

Returns:

Name Type Description
Path Path

Expected path of the downloaded audio file (with .mp3 suffix).

Source code in src/tnh_scholar/video_processing/video_processing_old1.py
def get_video_download_path_yt(output_dir: Path, url: str) -> Path:
    """
    Computes the expected audio download path for a video using yt-dlp metadata.

    Args:
        output_dir (Path): Directory where the audio file would be saved.
        url (str): The YouTube URL.

    Returns:
        Path: Expected path of the downloaded audio file (with .mp3 suffix).
    """
    ydl_opts = {
        "quiet": True,  # Suppress output
        "skip_download": True,  # Don't download, just fetch metadata
        "outtmpl": str(output_dir / "%(title)s.%(ext)s"),
    }

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(
            url, download=False
        )  # Extract metadata without downloading
        filepath = ydl.prepare_filename(info)

    return Path(filepath).with_suffix(".mp3")
get_youtube_urls_from_csv(file_path)

Reads a CSV file containing YouTube URLs and titles, logs the titles, and returns a list of URLs.

Parameters:

Name Type Description Default
file_path Path

Path to the CSV file containing YouTube URLs and titles.

required

Returns:

Type Description
List[str]

List of YouTube URLs.

Raises:

Type Description
FileNotFoundError

If the file does not exist.

ValueError

If the CSV file is improperly formatted.

Source code in src/tnh_scholar/video_processing/video_processing_old1.py
def get_youtube_urls_from_csv(file_path: Path) -> List[str]:
    """
    Reads a CSV file containing YouTube URLs and titles, logs the titles,
    and returns a list of URLs.

    Args:
        file_path (Path): Path to the CSV file containing YouTube URLs and titles.

    Returns:
        List[str]: List of YouTube URLs.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the CSV file is improperly formatted.
    """
    if not file_path.exists():
        logger.error(f"File not found: {file_path}")
        raise FileNotFoundError(f"File not found: {file_path}")

    urls = []

    try:
        with file_path.open("r", encoding="utf-8") as f:
            reader = csv.DictReader(f)

            if reader.fieldnames is None or "url" not in reader.fieldnames or "title" not in reader.fieldnames:
                logger.error("CSV file must contain 'url' and 'title' columns.")
                raise ValueError("CSV file must contain 'url' and 'title' columns.")

            for row in reader:
                url = row["url"]
                title = row["title"]
                urls.append(url)
                logger.info(f"Found video title: {title}")
    except Exception as e:
        logger.exception(f"Error processing CSV file: {e}")
        raise

    return urls

video_processing_old2

AUDIO_DOWNLOAD_OPTIONS = BASE_YDL_OPTIONS | {'format': 'bestaudio/best', 'postprocessors': [{'key': 'FFmpegExtractAudio', 'preferredcodec': 'mp3', 'preferredquality': '192'}], 'noplaylist': True} module-attribute
BASE_YDL_OPTIONS = {'quiet': True, 'no_warnings': True, 'extract_flat': True, 'socket_timeout': 30, 'retries': 3, 'ignoreerrors': True, 'logger': logger} module-attribute
DEFAULT_METADATA_FIELDS = ['id', 'title', 'description', 'duration', 'upload_date', 'uploader', 'channel_url', 'webpage_url', 'original_url', 'channel', 'language', 'categories', 'tags'] module-attribute
DEFAULT_TRANSCRIPT_DIR = Path.home() / '.yt_dlp_transcripts' module-attribute
TRANSCRIPT_OPTIONS = BASE_YDL_OPTIONS | {'writesubtitles': True, 'writeautomaticsub': True, 'subtitlesformat': 'ttml'} module-attribute
logger = get_child_logger(__name__) module-attribute
SubtitleTrack

Bases: TypedDict

Type definition for a subtitle track entry.

Source code in src/tnh_scholar/video_processing/video_processing_old2.py
class SubtitleTrack(TypedDict):
    """Type definition for a subtitle track entry."""
    url: str
    ext: str
    name: str
ext instance-attribute
name instance-attribute
url instance-attribute
TranscriptNotFoundError

Bases: Exception

Raised when no transcript is available for the requested language.

Source code in src/tnh_scholar/video_processing/video_processing_old2.py
class TranscriptNotFoundError(Exception):
    """Raised when no transcript is available for the requested language."""
    def __init__(self, video_url: str, language: str) -> None:
        self.video_url = video_url
        self.language = language
        message = (
            f"No transcript found for {self.video_url} "
            f"in language {self.language}."
        )
        super().__init__(message)
language = language instance-attribute
video_url = video_url instance-attribute
__init__(video_url, language)
Source code in src/tnh_scholar/video_processing/video_processing_old2.py
def __init__(self, video_url: str, language: str) -> None:
    self.video_url = video_url
    self.language = language
    message = (
        f"No transcript found for {self.video_url} "
        f"in language {self.language}."
    )
    super().__init__(message)
VideoDownload dataclass

Bases: VideoMetadata

Result of download operations.

Source code in src/tnh_scholar/video_processing/video_processing_old2.py
@dataclass
class VideoDownload(VideoMetadata):
    """Result of download operations."""
    filepath: Path
filepath instance-attribute
__init__(metadata, filepath)
VideoInfo

Bases: TypedDict

Type definition for relevant video info fields.

Source code in src/tnh_scholar/video_processing/video_processing_old2.py
class VideoInfo(TypedDict):
    """Type definition for relevant video info fields."""
    subtitles: Dict[str, List[SubtitleTrack]]
    automatic_captions: Dict[str, List[SubtitleTrack]]
automatic_captions instance-attribute
subtitles instance-attribute
VideoMetadata dataclass

Base class for video operations containing common metadata.

Source code in src/tnh_scholar/video_processing/video_processing_old2.py
@dataclass
class VideoMetadata:
    """Base class for video operations containing common metadata."""
    metadata: Dict[str, Any]
metadata instance-attribute
__init__(metadata)
VideoTranscript dataclass

Bases: VideoMetadata

Result of transcript operations.

Source code in src/tnh_scholar/video_processing/video_processing_old2.py
@dataclass
class VideoTranscript(VideoMetadata):
    """Result of transcript operations."""
    content: str
content instance-attribute
__init__(metadata, content)
download_audio_yt(url, output_dir, start_time=None)

Downloads audio from YouTube URL with optional start time.

Source code in src/tnh_scholar/video_processing/video_processing_old2.py
def download_audio_yt(
    url: str, 
    output_dir: Path, 
    start_time: Optional[str] = None
) -> VideoDownload:
    """Downloads audio from YouTube URL with optional start time."""
    output_dir.mkdir(parents=True, exist_ok=True)

    options = AUDIO_DOWNLOAD_OPTIONS | {
        "outtmpl": str(output_dir / "%(title)s.%(ext)s"),
    }

    if start_time:
        options["postprocessor_args"] = ["-ss", start_time]
        logger.info(f"Postprocessor start time set to: {start_time}")

    with yt_dlp.YoutubeDL(options) as ydl:
        if info := ydl.extract_info(url, download=True):
            filepath = Path(ydl.prepare_filename(info)).with_suffix(".mp3")
            metadata = _extract_metadata(info)
        else:
            logger.error(f"YT audio download: Unable to get info for {url}.")
            raise yt_dlp.utils.DownloadError(f"Unable to extract info for {url}.")
        return VideoDownload(metadata=metadata, filepath=filepath)
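`AUDIO_DOWNLOAD_OPTIONS` is built by merging overrides onto `BASE_YDL_OPTIONS` with the dict-union operator (`|`, Python 3.9+), and each call layers its own overrides the same way. A minimal sketch of the pattern (constants abbreviated; `make_call_options` is illustrative):

```python
from typing import Optional

# Abbreviated versions of the module-level constants above.
BASE_YDL_OPTIONS = {
    "quiet": True,
    "no_warnings": True,
    "socket_timeout": 30,
    "retries": 3,
}

AUDIO_DOWNLOAD_OPTIONS = BASE_YDL_OPTIONS | {
    "format": "bestaudio/best",
    "noplaylist": True,
}


def make_call_options(outtmpl: str, start_time: Optional[str] = None) -> dict:
    """Layer per-call overrides on top of the shared defaults (illustrative)."""
    # `|` produces a new dict, so the module-level constants stay untouched.
    options = AUDIO_DOWNLOAD_OPTIONS | {"outtmpl": outtmpl}
    if start_time:
        options["postprocessor_args"] = ["-ss", start_time]
    return options
```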
get_transcript(url, lang='en', download_dir=DEFAULT_TRANSCRIPT_DIR, keep_transcript_file=False)

Downloads and extracts transcript with metadata.

Source code in src/tnh_scholar/video_processing/video_processing_old2.py
def get_transcript(
    url: str,
    lang: str = "en",
    download_dir: Path = DEFAULT_TRANSCRIPT_DIR,
    keep_transcript_file: bool = False,
) -> VideoTranscript:
    """Downloads and extracts transcript with metadata."""
    transcript_file = _download_yt_ttml(download_dir, url=url, lang=lang)
    text = read_str_from_file(transcript_file)

    if not keep_transcript_file:
        try:
            os.remove(transcript_file)
            logger.debug(f"Removed temporary transcript file: {transcript_file}")
        except OSError as e:
            logger.warning(
                f"Failed to remove temporary transcript file {transcript_file}: {e}"
                )

    content = _extract_ttml_text(text)

    # Get metadata
    options = BASE_YDL_OPTIONS | {"skip_download": True}
    with yt_dlp.YoutubeDL(options) as ydl:
        if info := ydl.extract_info(url, download=False):
            metadata = _extract_metadata(info)
        else:
            logger.error(f"YT get transcript: unable to get info for {url}.")
            raise yt_dlp.utils.DownloadError(f"Unable to extract info for {url}.")
    return VideoTranscript(metadata=metadata, content=content)
get_video_download_path_yt(output_dir, url)

Get video metadata and expected download path.

Source code in src/tnh_scholar/video_processing/video_processing_old2.py
def get_video_download_path_yt(output_dir: Path, url: str) -> VideoDownload:
    """Get video metadata and expected download path."""
    options = AUDIO_DOWNLOAD_OPTIONS | {
        "skip_download": True,
        "outtmpl": str(output_dir / "%(title)s.%(ext)s"),
    }

    with yt_dlp.YoutubeDL(options) as ydl:
        if info := ydl.extract_info(url, download=False):
            filepath = Path(ydl.prepare_filename(info)).with_suffix(".mp3")
            metadata = _extract_metadata(info)
        else:
            logger.error(f"YT video download: unable to extract info for {url}")
            raise yt_dlp.utils.DownloadError(f"Unable to extract info for {url}.")
        return VideoDownload(metadata=metadata, filepath=filepath)
get_video_metadata(url)

Get metadata for a YouTube video without downloading content.

Parameters:

Name Type Description Default
url str

YouTube video URL

required

Returns:

Type Description
VideoResult

VideoResult with only metadata field populated

Raises:

Type Description
DownloadError

If video info extraction fails

Source code in src/tnh_scholar/video_processing/video_processing_old2.py
def get_video_metadata(url: str) -> VideoResult:
    """Get metadata for a YouTube video without downloading content.

    Args:
        url: YouTube video URL

    Returns:
        VideoResult with only metadata field populated

    Raises:
        yt_dlp.utils.DownloadError: If video info extraction fails
    """
    options = BASE_YDL_OPTIONS | {"skip_download": True}

    with yt_dlp.YoutubeDL(options) as ydl:
        info = ydl.extract_info(url, download=False)
        metadata = _extract_metadata(info)
        return VideoResult(metadata=metadata)
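`_extract_metadata` is not shown in this reference; a plausible sketch, assuming it simply whitelists the fields in `DEFAULT_METADATA_FIELDS` from the yt-dlp info dict:

```python
from typing import Any, Dict, List

DEFAULT_METADATA_FIELDS: List[str] = [
    "id", "title", "description", "duration", "upload_date",
    "uploader", "channel_url", "webpage_url", "original_url",
    "channel", "language", "categories", "tags",
]


def extract_metadata(info: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only whitelisted fields from a yt-dlp info dict (hypothetical sketch)."""
    return {k: info[k] for k in DEFAULT_METADATA_FIELDS if k in info}
```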
get_youtube_urls_from_csv(file_path)

Reads YouTube URLs from a CSV file containing URLs and titles.

Source code in src/tnh_scholar/video_processing/video_processing_old2.py
def get_youtube_urls_from_csv(file_path: Path) -> List[str]:
    """Reads YouTube URLs from a CSV file containing URLs and titles."""
    if not file_path.exists():
        logger.error(f"File not found: {file_path}")
        raise FileNotFoundError(f"File not found: {file_path}")

    urls = []
    try:
        with file_path.open("r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            if reader.fieldnames is None \
                or "url" not in reader.fieldnames \
                or "title" not in reader.fieldnames:
                logger.error("CSV file must contain 'url' and 'title' columns.")
                raise ValueError("CSV file must contain 'url' and 'title' columns.")

            for row in reader:
                urls.append(row["url"])
                logger.info(f"Found video title: {row['title']}")
    except Exception as e:
        logger.exception(f"Error processing CSV file: {e}")
        raise

    return urls

yt_transcribe

DEFAULT_CHUNK_DURATION_MS = 10 * 60 * 1000 module-attribute
DEFAULT_CHUNK_DURATION_S = 10 * 60 module-attribute
DEFAULT_OUTPUT_DIR = './video_transcriptions' module-attribute
DEFAULT_PROMPT = 'Dharma, Deer Park, Thay, Thich Nhat Hanh, Bodhicitta, Bodhisattva, Mahayana' module-attribute
EXPECTED_ENV = 'tnh-scholar' module-attribute
args = parser.parse_args() module-attribute
group = parser.add_mutually_exclusive_group(required=True) module-attribute
logger = get_child_logger('yt_transcribe') module-attribute
output_directory = Path(args.output_dir) module-attribute
parser = argparse.ArgumentParser(description='Transcribe YouTube videos from a URL or a file containing URLs.') module-attribute
url_file = Path(args.file) module-attribute
video_urls = [] module-attribute
check_conda_env()
Source code in src/tnh_scholar/video_processing/yt_transcribe.py
def check_conda_env():
    active_env = os.environ.get("CONDA_DEFAULT_ENV")
    if active_env != EXPECTED_ENV:
        logger.warning(
            f"WARNING: The active conda environment is '{active_env}', but '{EXPECTED_ENV}' is required. "
            "Please activate the correct environment."
        )
        # Exit because the wrong environment is active.
        sys.exit(1)
transcribe_youtube_videos(urls, output_base_dir, max_chunk_duration=DEFAULT_CHUNK_DURATION_S, start=None, translate=False)

Full pipeline for transcribing a list of YouTube videos.

Parameters:

Name Type Description Default
urls list[str]

List of YouTube video URLs.

required
output_base_dir Path

Base directory for storing output.

required
max_chunk_duration int

Maximum duration for audio chunks in seconds (default is 10 minutes).

DEFAULT_CHUNK_DURATION_S
start Optional[str]

Optional start time passed to the audio download (e.g., '00:01:30').

None
translate bool

Whether to request translation when transcribing chunks.

False
Source code in src/tnh_scholar/video_processing/yt_transcribe.py
def transcribe_youtube_videos(
    urls: list[str],
    output_base_dir: Path,
    max_chunk_duration: int = DEFAULT_CHUNK_DURATION_S,
    start: Optional[str] = None,
    translate: bool = False,
):
    """
    Full pipeline for transcribing a list of YouTube videos.

    Args:
        urls (list[str]): List of YouTube video URLs.
        output_base_dir (Path): Base directory for storing output.
        max_chunk_duration (int): Maximum duration for audio chunks in seconds (default is 10 minutes).
        start (Optional[str]): Optional start time passed to the audio download (e.g., '00:01:30').
        translate (bool): Whether to request translation when transcribing chunks (default: False).
    """
    output_base_dir.mkdir(parents=True, exist_ok=True)

    for url in urls:
        try:
            logger.info(f"Processing video: {url}")

            # Step 1: Download audio
            logger.info("Downloading audio...")
            tmp_audio_file = download_audio_yt(url, output_base_dir, start_time=start)
            logger.info(f"Downloaded audio file: {tmp_audio_file}")

            # Prepare directories for chunks and outputs
            video_name = (
                tmp_audio_file.stem
            )  # Use the stem of the audio file (title without extension)
            video_output_dir = output_base_dir / video_name
            chunks_dir = video_output_dir / "chunks"
            chunks_dir.mkdir(parents=True, exist_ok=True)

            # Create the video directory and move the audio file into it
            video_output_dir.mkdir(parents=True, exist_ok=True)
            audio_file = video_output_dir / tmp_audio_file.name

            try:
                tmp_audio_file.rename(
                    audio_file
                )  # Move the audio file to the video directory
                logger.info(f"Moved audio file to: {audio_file}")
            except Exception as e:
                logger.error(f"Failed to move audio file to {video_output_dir}: {e}")
                # Ensure the code gracefully handles issues here, reassigning to the original tmp path.
                audio_file = tmp_audio_file

            # Step 2: Detect boundaries
            logger.info("Detecting boundaries...")
            boundaries = detect_boundaries(audio_file)
            logger.info("Boundaries generated.")

            # Step 3: Split audio into chunks
            logger.info("Splitting audio into chunks...")
            split_audio_at_boundaries(
                audio_file=audio_file,
                boundaries=boundaries,
                output_dir=chunks_dir,
                max_duration=max_chunk_duration,
            )
            logger.info(f"Audio chunks saved to: {chunks_dir}")

            # Step 4: Transcribe audio chunks
            logger.info("Transcribing audio chunks...")
            transcript_file = video_output_dir / f"{video_name}.txt"
            jsonl_file = video_output_dir / f"{video_name}.jsonl"
            process_audio_chunks(
                directory=chunks_dir,
                output_file=transcript_file,
                jsonl_file=jsonl_file,
                prompt=DEFAULT_PROMPT,
                translate=translate,
            )
            logger.info(f"Transcription completed for {url}")
            logger.info(f"Transcript saved to: {transcript_file}")
            logger.info(f"Raw transcription data saved to: {jsonl_file}")

        except Exception as e:
            logger.error(f"Failed to process video {url}: {e}")
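The per-video directory layout built in the steps above can be derived up front. `plan_output_paths` below is an illustrative helper that follows the path names used in the pipeline:

```python
from pathlib import Path


def plan_output_paths(tmp_audio_file: Path, output_base_dir: Path) -> dict:
    """Derive the per-video output layout used by the pipeline (illustrative)."""
    video_name = tmp_audio_file.stem  # title without extension
    video_output_dir = output_base_dir / video_name
    return {
        "video_dir": video_output_dir,
        "chunks_dir": video_output_dir / "chunks",  # audio chunks live here
        "audio_file": video_output_dir / tmp_audio_file.name,
        "transcript_file": video_output_dir / f"{video_name}.txt",
        "jsonl_file": video_output_dir / f"{video_name}.jsonl",
    }
```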

xml_processing

FormattingError

Bases: Exception

Custom exception raised for formatting-related errors.

Source code in src/tnh_scholar/xml_processing/xml_processing.py
class FormattingError(Exception):
    """
    Custom exception raised for formatting-related errors.
    """

    def __init__(self, message="An error occurred due to invalid formatting."):
        super().__init__(message)
__init__(message='An error occurred due to invalid formatting.')
Source code in src/tnh_scholar/xml_processing/xml_processing.py
def __init__(self, message="An error occurred due to invalid formatting."):
    super().__init__(message)

PagebreakXMLParser

Parses XML documents split by <pagebreak> tags, with optional grouping and tag retention.

Source code in src/tnh_scholar/xml_processing/xml_processing.py
class PagebreakXMLParser:
    """
    Parses XML documents split by <pagebreak> tags, with optional grouping and tag retention.
    """

    def __init__(self, text: str):
        if not text or not text.strip():
            raise ValueError("Input XML text is empty or whitespace.")
        self.original_text = text
        self.cleaned_text = ""
        self.pages: List[str] = []
        self.pagebreak_tags: List[str] = []
        self._xml_decl_pattern = re.compile(r"^\s*<\?xml[^>]*\?>\s*", re.IGNORECASE)
        self._document_open_pattern = re.compile(r"^\s*<document>\s*", re.IGNORECASE)
        self._document_close_pattern = re.compile(r"\s*</document>\s*$", re.IGNORECASE)
        self._pagebreak_pattern = re.compile(r"^\s*<pagebreak\b[^>]*/>\s*$", re.IGNORECASE | re.MULTILINE)

    def _remove_preamble_and_document_tags(self):
        text = self._xml_decl_pattern.sub("", self.original_text, count=1)
        text = self._document_open_pattern.sub("", text, count=1)
        text = self._document_close_pattern.sub("", text, count=1)
        if not text.strip():
            raise ValueError("No content found between <document> tags.")
        self.cleaned_text = text

    def _split_on_pagebreaks(self):
        self.pages = []
        self.pagebreak_tags = re.findall(self._pagebreak_pattern, self.cleaned_text)
        split_lines = re.split(self._pagebreak_pattern, self.cleaned_text)
        for i, page_content in enumerate(split_lines):
            page_content = page_content.strip()
            # skip trailing empty after last pagebreak
            if not page_content and (i >= len(self.pagebreak_tags)):
                continue
            self.pages.append(page_content)

    def _attach_pagebreaks(self, keep_pagebreaks: bool):
        if not keep_pagebreaks:
            return
        for i in range(min(len(self.pages), len(self.pagebreak_tags))):
            if self.pages[i]:
                self.pages[i] = f"{self.pages[i]}\n{self.pagebreak_tags[i].strip()}"
            else:
                self.pages[i] = self.pagebreak_tags[i].strip()

    def _group_pages(self, page_groups: List[Tuple[int, int]]) -> List[str]:
        grouped_pages: List[str] = []
        for start, end in page_groups:
            if start < 1 or end < start:
                continue  # skip invalid groups
            if group := [
                self.pages[i]
                for i in range(start - 1, end)
                if 0 <= i < len(self.pages)
            ]:
                grouped_pages.append("\n".join(group).strip())
        return grouped_pages

    def parse(
        self,
        page_groups: Optional[List[Tuple[int, int]]] = None,
        keep_pagebreaks: bool = True,
    ) -> List[str]:
        """
        Parses the XML and returns a list of page contents, optionally grouped and with pagebreaks retained.
        """
        self._remove_preamble_and_document_tags()
        self._split_on_pagebreaks()
        self._attach_pagebreaks(keep_pagebreaks)
        # Remove empty pages
        self.pages = [p for p in self.pages if p]
        if not self.pages:
            raise ValueError("No pages found in the XML content after splitting on <pagebreak> tags.")
        return self._group_pages(page_groups) if page_groups else self.pages
cleaned_text = '' instance-attribute
original_text = text instance-attribute
pagebreak_tags = [] instance-attribute
pages = [] instance-attribute
__init__(text)
Source code in src/tnh_scholar/xml_processing/xml_processing.py
def __init__(self, text: str):
    if not text or not text.strip():
        raise ValueError("Input XML text is empty or whitespace.")
    self.original_text = text
    self.cleaned_text = ""
    self.pages: List[str] = []
    self.pagebreak_tags: List[str] = []
    self._xml_decl_pattern = re.compile(r"^\s*<\?xml[^>]*\?>\s*", re.IGNORECASE)
    self._document_open_pattern = re.compile(r"^\s*<document>\s*", re.IGNORECASE)
    self._document_close_pattern = re.compile(r"\s*</document>\s*$", re.IGNORECASE)
    self._pagebreak_pattern = re.compile(r"^\s*<pagebreak\b[^>]*/>\s*$", re.IGNORECASE | re.MULTILINE)
parse(page_groups=None, keep_pagebreaks=True)

Parses the XML and returns a list of page contents, optionally grouped and with pagebreaks retained.

Source code in src/tnh_scholar/xml_processing/xml_processing.py
def parse(
    self,
    page_groups: Optional[List[Tuple[int, int]]] = None,
    keep_pagebreaks: bool = True,
) -> List[str]:
    """
    Parses the XML and returns a list of page contents, optionally grouped and with pagebreaks retained.
    """
    self._remove_preamble_and_document_tags()
    self._split_on_pagebreaks()
    self._attach_pagebreaks(keep_pagebreaks)
    # Remove empty pages
    self.pages = [p for p in self.pages if p]
    if not self.pages:
        raise ValueError("No pages found in the XML content after splitting on <pagebreak> tags.")
    return self._group_pages(page_groups) if page_groups else self.pages
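The core splitting step can be exercised on its own with the same regular expression. `split_pages` below is a condensed, illustrative version of `_split_on_pagebreaks`:

```python
import re
from typing import List

# Same pattern as PagebreakXMLParser: a standalone <pagebreak .../> line.
_pagebreak_pattern = re.compile(
    r"^\s*<pagebreak\b[^>]*/>\s*$", re.IGNORECASE | re.MULTILINE
)


def split_pages(cleaned_text: str) -> List[str]:
    """Split document body text on standalone <pagebreak .../> lines."""
    parts = re.split(_pagebreak_pattern, cleaned_text)
    return [p.strip() for p in parts if p.strip()]
```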

join_xml_data_to_doc(file_path, data, overwrite=False)

Joins a list of XML-tagged data with newlines, wraps it with <document> tags, and writes it to the specified file. Raises an exception if the file exists and overwrite is not set.

Parameters:

    file_path (Path): Path to the output file. Required.
    data (List[str]): List of XML-tagged data strings. Required.
    overwrite (bool): Whether to overwrite the file if it exists. Default: False.

Raises:

    FileExistsError: If the file exists and overwrite is False.
    ValueError: If the data list is empty.

Example:

    join_xml_data_to_doc(Path("output.xml"), ["<tag>Data</tag>"], overwrite=True)

Source code in src/tnh_scholar/xml_processing/xml_processing.py (lines 88-121)
def join_xml_data_to_doc(
    file_path: Path, data: List[str], overwrite: bool = False
) -> None:
    """
    Joins a list of XML-tagged data with newlines, wraps it with <document> tags,
    and writes it to the specified file. Raises an exception if the file exists
    and overwrite is not set.

    Args:
        file_path (Path): Path to the output file.
        data (List[str]): List of XML-tagged data strings.
        overwrite (bool): Whether to overwrite the file if it exists.

    Raises:
        FileExistsError: If the file exists and overwrite is False.
        ValueError: If the data list is empty.

    Example:
        >>> join_xml_data_to_doc(Path("output.xml"), ["<tag>Data</tag>"], overwrite=True)
    """
    if file_path.exists() and not overwrite:
        raise FileExistsError(
            f"The file {file_path} already exists and overwrite is not set."
        )

    if not data:
        raise ValueError("The data list cannot be empty.")

    # Create the XML content
    joined_data = "\n".join(data)  # Joining data with newline
    xml_content = f"<document>\n{joined_data}\n</document>"

    # Write to file
    file_path.write_text(xml_content, encoding="utf-8")
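A minimal usage sketch, writing to a temporary directory so nothing on disk is clobbered. The helper below is condensed from the source listing above rather than imported from the installed package:

```python
import tempfile
from pathlib import Path
from typing import List

def join_xml_data_to_doc(file_path: Path, data: List[str], overwrite: bool = False) -> None:
    # Condensed from the source listing above.
    if file_path.exists() and not overwrite:
        raise FileExistsError(f"The file {file_path} already exists and overwrite is not set.")
    if not data:
        raise ValueError("The data list cannot be empty.")
    file_path.write_text("<document>\n" + "\n".join(data) + "\n</document>", encoding="utf-8")

with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "output.xml"
    join_xml_data_to_doc(out, ["<title>A</title>", "<body>B</body>"])
    print(out.read_text(encoding="utf-8"))
    # <document>
    # <title>A</title>
    # <body>B</body>
    # </document>
```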

remove_page_tags(text)

Removes <page ...> and </page> tags from a text string.

Parameters: - text (str): The input text containing <page> tags.

Returns: - str: The text with <page> tags removed.

Source code in src/tnh_scholar/xml_processing/xml_processing.py (lines 124-138)
def remove_page_tags(text):
    """
    Removes <page ...> and </page> tags from a text string.

    Parameters:
    - text (str): The input text containing <page> tags.

    Returns:
    - str: The text with <page> tags removed.
    """
    # Remove opening <page ...> tags
    text = re.sub(r"<page[^>]*>", "", text)
    # Remove closing </page> tags
    text = re.sub(r"</page>", "", text)
    return text
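The two substitutions can be exercised directly with a self-contained copy of the function above:

```python
import re

def remove_page_tags(text: str) -> str:
    # Same two substitutions as the source listing above.
    text = re.sub(r"<page[^>]*>", "", text)  # opening <page ...> tags
    text = re.sub(r"</page>", "", text)      # closing </page> tags
    return text

print(remove_page_tags("<page n='3'>Breathing in</page>"))  # Breathing in
```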

save_pages_to_xml(output_xml_path, text_pages, overwrite=False)

Generates and saves an XML file containing text pages, with a <pagebreak> tag indicating where each page ends.

Parameters:

    output_xml_path (Path): The Path object for the file where the XML file will be saved. Required.
    text_pages (List[str]): A list of strings, each representing the text content of a page. Required.
    overwrite (bool): If True, overwrites the file if it exists. Default: False.

Returns:

    None

Raises:

    ValueError: If the input list of text_pages is empty or contains invalid types.
    FileExistsError: If the file already exists and overwrite is False.
    PermissionError: If the file cannot be created due to insufficient permissions.
    OSError: For other file I/O-related errors.

Source code in src/tnh_scholar/xml_processing/xml_processing.py (lines 16-85)
def save_pages_to_xml(
    output_xml_path: Path,
    text_pages: List[str],
    overwrite: bool = False,
) -> None:
    """
    Generates and saves an XML file containing text pages, with a <pagebreak> tag indicating the page ends.

    Parameters:
        output_xml_path (Path): The Path object for the file where the XML file will be saved.
        text_pages (List[str]): A list of strings, each representing the text content of a page.
        overwrite (bool): If True, overwrites the file if it exists. Default is False.

    Returns:
        None

    Raises:
        ValueError: If the input list of text_pages is empty or contains invalid types.
        FileExistsError: If the file already exists and overwrite is False.
        PermissionError: If the file cannot be created due to insufficient permissions.
        OSError: For other file I/O-related errors.
    """
    if not text_pages:
        raise ValueError("The text_pages list is empty. Cannot generate XML.")

    # Check if the file exists and handle overwrite behavior
    if output_xml_path.exists() and not overwrite:
        raise FileExistsError(
            f"The file '{output_xml_path}' already exists. Set overwrite=True to overwrite."
        )

    try:
        # Ensure the output directory exists
        output_xml_path.parent.mkdir(parents=True, exist_ok=True)

        # Write the XML file
        with output_xml_path.open("w", encoding="utf-8") as xml_file:
            # Write XML declaration and root element
            xml_file.write("<?xml version='1.0' encoding='UTF-8'?>\n")
            xml_file.write("<document>\n")

            # Add each page with its content and <pagebreak> tag
            for page_number, text in enumerate(text_pages, start=1):
                if not isinstance(text, str):
                    raise ValueError(
                        f"Invalid page content at index {page_number - 1}: expected a string."
                    )

                content = text.strip()
                escaped_text = escape(content)
                xml_file.write(f"    {escaped_text}\n")
                xml_file.write(f"    <pagebreak page='{page_number}' />\n")

            # Close the root element
            xml_file.write("</document>\n")

        print(f"XML file successfully saved at {output_xml_path}")

    except PermissionError as e:
        raise PermissionError(
            f"Permission denied while writing to {output_xml_path}: {e}"
        ) from e

    except OSError as e:
        raise OSError(
            f"An OS-related error occurred while saving XML file at {output_xml_path}: {e}"
        ) from e

    except Exception as e:
        raise RuntimeError(f"An unexpected error occurred: {e}") from e
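The document layout this produces can be sketched without touching the filesystem. `render_pages` below is a hypothetical helper that mirrors only the formatting logic of the function above (XML declaration, `<document>` root, escaped page text, numbered `<pagebreak/>` tags), returning a string instead of writing a file:

```python
from typing import List
from xml.sax.saxutils import escape

def render_pages(text_pages: List[str]) -> str:
    # Hypothetical helper: reproduces the layout save_pages_to_xml writes,
    # as a string instead of a file.
    lines = ["<?xml version='1.0' encoding='UTF-8'?>", "<document>"]
    for page_number, text in enumerate(text_pages, start=1):
        lines.append(f"    {escape(text.strip())}")
        lines.append(f"    <pagebreak page='{page_number}' />")
    lines.append("</document>")
    return "\n".join(lines) + "\n"

print(render_pages(["Peace & joy", "Second page"]))
```

Note that page text is escaped, so an ampersand in a page comes out as `&amp;`.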

split_xml_on_pagebreaks(text, page_groups=None, keep_pagebreaks=True)

Splits an XML document into individual pages based on <pagebreak> tags. Optionally groups pages together based on page_groups and retains <pagebreak> tags if keep_pagebreaks is True.

Source code in src/tnh_scholar/xml_processing/xml_processing.py (lines 217-228)
def split_xml_on_pagebreaks(
    text: str,
    page_groups: Optional[List[Tuple[int, int]]] = None,
    keep_pagebreaks: bool = True,
) -> List[str]:
    """
    Splits an XML document into individual pages based on <pagebreak> tags.
    Optionally groups pages together based on page_groups
    and retains <pagebreak> tags if keep_pagebreaks is True.
    """
    parser = PagebreakXMLParser(text)
    return parser.parse(page_groups=page_groups, keep_pagebreaks=keep_pagebreaks)
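The splitting behavior can be sketched with the same regular expressions the parser compiles. This is a simplified standalone version; the real implementation also handles page grouping and pagebreak retention:

```python
import re

# Same pagebreak pattern PagebreakXMLParser compiles.
PAGEBREAK = re.compile(r"^\s*<pagebreak\b[^>]*/>\s*$", re.IGNORECASE | re.MULTILINE)

doc = (
    "<?xml version='1.0' encoding='UTF-8'?>\n"
    "<document>\n"
    "    First page\n"
    "    <pagebreak page='1' />\n"
    "    Second page\n"
    "    <pagebreak page='2' />\n"
    "</document>\n"
)

# Strip the declaration and <document> wrapper, then split on pagebreaks.
body = re.sub(r"^\s*<\?xml[^>]*\?>\s*", "", doc, count=1)
body = re.sub(r"^\s*<document>\s*", "", body, count=1)
body = re.sub(r"\s*</document>\s*$", "", body, count=1)
pages = [p.strip() for p in PAGEBREAK.split(body) if p.strip()]
print(pages)  # ['First page', 'Second page']
```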

split_xml_pages(text)

Backwards-compatible helper that returns the page contents without pagebreak tags.

Parameters:

    text (str): XML document string. Required.

Returns:

    List[str]: List of page strings.

Source code in src/tnh_scholar/xml_processing/xml_processing.py (lines 231-241)
def split_xml_pages(text: str) -> List[str]:
    """
    Backwards-compatible helper that returns the page contents without pagebreak tags.

    Args:
        text: XML document string.

    Returns:
        List of page strings.
    """
    return split_xml_on_pagebreaks(text, keep_pagebreaks=False)

extract_tags

extract_unique_tags(xml_file)

Extract all unique tags from an XML file using lxml.

Parameters:

    xml_file (str): Path to the XML file. Required.

Returns:

    set: A set of unique tags in the XML document.

Source code in src/tnh_scholar/xml_processing/extract_tags.py (lines 6-20)
def extract_unique_tags(xml_file):
    """
    Extract all unique tags from an XML file using lxml.

    Parameters:
        xml_file (str): Path to the XML file.

    Returns:
        set: A set of unique tags in the XML document.
    """
    # Parse the XML file
    tree = etree.parse(xml_file)

    # Find all unique tags and return
    return {element.tag for element in tree.iter()}
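For example, with the standard-library ElementTree as a stand-in for lxml.etree (the parse-and-iterate API is the same for this purpose):

```python
import io
import xml.etree.ElementTree as ET  # stdlib stand-in for lxml.etree

def extract_unique_tags(xml_file):
    # Same logic as above, via ElementTree.
    tree = ET.parse(xml_file)
    return {element.tag for element in tree.iter()}

sample = io.StringIO("<document><page>a</page><page>b</page><note/></document>")
print(sorted(extract_unique_tags(sample)))  # ['document', 'note', 'page']
```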
main()
Source code in src/tnh_scholar/xml_processing/extract_tags.py (lines 23-39)
def main():
    # Create argument parser
    parser = argparse.ArgumentParser(
        description="Extract all unique tags from an XML file."
    )
    parser.add_argument("xml_file", type=str, help="Path to the XML file.")

    # Parse command-line arguments
    args = parser.parse_args()

    # Extract tags
    tags = extract_unique_tags(args.xml_file)

    # Print results
    print("Unique Tags Found:")
    for tag in sorted(tags):
        print(tag)

xml_processing

FormattingError

Bases: Exception

Custom exception raised for formatting-related errors.

Source code in src/tnh_scholar/xml_processing/xml_processing.py (lines 7-13)
class FormattingError(Exception):
    """
    Custom exception raised for formatting-related errors.
    """

    def __init__(self, message="An error occurred due to invalid formatting."):
        super().__init__(message)
__init__(message='An error occurred due to invalid formatting.')
Source code in src/tnh_scholar/xml_processing/xml_processing.py (lines 12-13)
def __init__(self, message="An error occurred due to invalid formatting."):
    super().__init__(message)
PagebreakXMLParser

Parses XML documents split by <pagebreak> tags, with optional grouping and pagebreak-tag retention.

Source code in src/tnh_scholar/xml_processing/xml_processing.py (lines 141-214)
class PagebreakXMLParser:
    """
    Parses XML documents split by <pagebreak> tags, with optional grouping and tag retention.
    """

    def __init__(self, text: str):
        if not text or not text.strip():
            raise ValueError("Input XML text is empty or whitespace.")
        self.original_text = text
        self.cleaned_text = ""
        self.pages: List[str] = []
        self.pagebreak_tags: List[str] = []
        self._xml_decl_pattern = re.compile(r"^\s*<\?xml[^>]*\?>\s*", re.IGNORECASE)
        self._document_open_pattern = re.compile(r"^\s*<document>\s*", re.IGNORECASE)
        self._document_close_pattern = re.compile(r"\s*</document>\s*$", re.IGNORECASE)
        self._pagebreak_pattern = re.compile(r"^\s*<pagebreak\b[^>]*/>\s*$", re.IGNORECASE | re.MULTILINE)

    def _remove_preamble_and_document_tags(self):
        text = self._xml_decl_pattern.sub("", self.original_text, count=1)
        text = self._document_open_pattern.sub("", text, count=1)
        text = self._document_close_pattern.sub("", text, count=1)
        if not text.strip():
            raise ValueError("No content found between <document> tags.")
        self.cleaned_text = text

    def _split_on_pagebreaks(self):
        self.pages = []
        self.pagebreak_tags = re.findall(self._pagebreak_pattern, self.cleaned_text)
        split_lines = re.split(self._pagebreak_pattern, self.cleaned_text)
        for i, page_content in enumerate(split_lines):
            page_content = page_content.strip()
            # skip trailing empty after last pagebreak
            if not page_content and (i >= len(self.pagebreak_tags)):
                continue
            self.pages.append(page_content)

    def _attach_pagebreaks(self, keep_pagebreaks: bool):
        if not keep_pagebreaks:
            return
        for i in range(min(len(self.pages), len(self.pagebreak_tags))):
            if self.pages[i]:
                self.pages[i] = f"{self.pages[i]}\n{self.pagebreak_tags[i].strip()}"
            else:
                self.pages[i] = self.pagebreak_tags[i].strip()

    def _group_pages(self, page_groups: List[Tuple[int, int]]) -> List[str]:
        grouped_pages: List[str] = []
        for start, end in page_groups:
            if start < 1 or end < start:
                continue  # skip invalid groups
            if group := [
                self.pages[i]
                for i in range(start - 1, end)
                if 0 <= i < len(self.pages)
            ]:
                grouped_pages.append("\n".join(group).strip())
        return grouped_pages

    def parse(
        self,
        page_groups: Optional[List[Tuple[int, int]]] = None,
        keep_pagebreaks: bool = True,
    ) -> List[str]:
        """
        Parses the XML and returns a list of page contents, optionally grouped and with pagebreaks retained.
        """
        self._remove_preamble_and_document_tags()
        self._split_on_pagebreaks()
        self._attach_pagebreaks(keep_pagebreaks)
        # Remove empty pages
        self.pages = [p for p in self.pages if p]
        if not self.pages:
            raise ValueError("No pages found in the XML content after splitting on <pagebreak> tags.")
        return self._group_pages(page_groups) if page_groups else self.pages
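The page_groups argument to parse() takes 1-indexed, inclusive (start, end) tuples. The grouping step can be sketched in isolation; `group_pages` is a hypothetical standalone mirror of the `_group_pages` method above:

```python
from typing import List, Tuple

def group_pages(pages: List[str], page_groups: List[Tuple[int, int]]) -> List[str]:
    # Mirrors _group_pages: 1-indexed inclusive ranges; invalid or
    # fully out-of-range groups are skipped rather than raising.
    grouped: List[str] = []
    for start, end in page_groups:
        if start < 1 or end < start:
            continue  # skip invalid groups
        if group := [pages[i] for i in range(start - 1, end) if 0 <= i < len(pages)]:
            grouped.append("\n".join(group).strip())
    return grouped

print(group_pages(["p1", "p2", "p3", "p4"], [(1, 2), (3, 4)]))  # ['p1\np2', 'p3\np4']
```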